Let the compiler do the work for you!

Jan van Brügge - Aug 1 '19 - Dev Community

Recently I came across a little programming puzzle. The task was to take a binary search tree and return a new tree in which the value of every node is replaced by the sum of all nodes to its right. Take this tree as an example:

    5
   / \
  2    7
 /    / \
1    6   8
      \
       6

We start at the rightmost node and set it to zero, because nothing lies to its right and the sum of nothing is zero. The 7 above it comes next and becomes 8, because there is only one node to its right. After this comes the "second" 6, the child of the other 6, because it is further to the right than its parent. In general, the order is always right subtree, self, left subtree. We end up with the following tree; check that you understand where every number comes from:

    27
   /  \
  32   8
 /    / \
34   21  0
      \
       15
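
To double-check the numbers, here is a small sketch in Java that does the same arithmetic on a flat array. It assumes we have already written down the node values in visiting order (right subtree, self, left subtree). The names RunningSums and sumsBefore are made up for illustration and are not part of the puzzle.

```java
import java.util.Arrays;

class RunningSums {
    // Replaces each value with the sum of all values visited before it.
    static int[] sumsBefore(int[] visitOrder) {
        int[] result = new int[visitOrder.length];
        int runningSum = 0;
        for (int i = 0; i < visitOrder.length; i++) {
            result[i] = runningSum;       // sum of everything further to the right
            runningSum += visitOrder[i];  // now include the current node
        }
        return result;
    }

    public static void main(String[] args) {
        // Values of the example tree in visiting order:
        // 8, 7, the "second" 6, its parent 6, then 5, 2 and 1.
        int[] visitOrder = {8, 7, 6, 6, 5, 2, 1};
        System.out.println(Arrays.toString(sumsBefore(visitOrder)));
        // prints [0, 8, 15, 21, 27, 32, 34]
    }
}
```

These are exactly the numbers in the result tree above, read in the same visiting order.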

Implementing this algorithm in Java

For the following implementations, we will only use stuff from the standard library.

First, we define ourselves a tree:

public class Node {
    Node left;
    int value;
    Node right;

    public Node(Node l, int v, Node r) {
        this.left = l;
        this.value = v;
        this.right = r;
    }
}

Now, as always with recursive data structures like lists and trees, we will define the algorithm as a recursive function. We will need to keep track of the current sum and just build up a new tree with the new data.

public class Pair {
    Node n;
    int v;

    public Pair(Node n, int v) {
        this.n = n;
        this.v = v;
    }
}

public class Node {
    Node left;
    int value;
    Node right;

    public Node(Node l, int v, Node r) { /* ... */ }

    public Pair solve(int currentSum) {
        // Store the current sum in here, as fallback if right is null
        Pair rightResult = new Pair(null, currentSum);
        // First, go to the right subtree, if it exists
        if(this.right != null) {
            rightResult = this.right.solve(currentSum);
        }
        int sum = rightResult.v + this.value;

        // Again, save the sum as fallback
        Pair leftResult = new Pair(null, sum);
        if(this.left != null) {
            leftResult = this.left.solve(sum);
        }

        // Finally create a new node (to replace self)
        Node newSelf = new Node(leftResult.n, rightResult.v, rightResult.n);
        // And return it together with the sum
        return new Pair(newSelf, leftResult.v);
    }
}

Now, we just need a main function to run this:

public class Main {
    static Node testTree = new Node(
        new Node(
            new Node(null, 1, null),
            2,
            null
        ),
        5,
        new Node(
            new Node(
                null,
                6,
                new Node(null, 6, null)
            ),
            7,
            new Node(null, 8, null)
        )
    );

    public static void main(String[] args) {
        Pair result = testTree.solve(0);
        System.out.println(result.n);
    }
}

To see our result, we also need a toString method on the Node:

public class Node {
    Node left;
    int value;
    Node right;

    public Node(Node l, int v, Node r) { /* ... */ }

    public Pair solve(int currentSum) { /* ... */ }

    public String toString() {
        String leftTree = left == null ? " " : left.toString();
        String rightTree = right == null ? " " : right.toString();
        return "Node(" + leftTree + ", " + value + ", " + rightTree + ")";
    }
}

If you run the code now, you will see

Node(Node(Node( , 34,  ), 32,  ), 27, Node(Node( , 21, Node( , 15,  )), 8, Node( , 0,  )))

which is exactly the tree we are looking for!

Implementing this in Haskell

Now you may ask, what does this have to do with the title? The compiler does not write the code for us here. This is because the Java compiler is merely a "checking" compiler (as I call it). It complains if you mismatch types or make syntax errors, but it does not do any work for you. It is basically like a teacher who checks your answer afterwards.

The Haskell compiler is different, for several reasons. First, Haskell infers types, so you do not have to annotate them everywhere, and error messages also say which type is expected where. Second, Haskell focuses a lot more on fundamental abstractions than other commonly used languages.

One of these fundamental abstractions is the Functor. If your datatype is a functor, you can map a function over it. Nothing more, nothing less. For example, an array or list is a Functor where you apply a function to each element (you may know this from Java streams or JavaScript's Array.prototype.map). In Haskell, this works over any datatype that "contains" other data. So let's define such a datatype. Note that we leave the type of the data in the tree abstract; in the Java example it was int:

module Main where

data Tree a = Leaf | Node (Tree a) a (Tree a)

This declaration is fundamentally the same as the Node class in the Java example. The only differences are that we use Leaf instead of null to signal empty children, and that we do not name the fields left, value and right, but just put them in that order.

So, back to Functor: the compiler is able to automatically generate the code you would need to implement this mapping. All you need to do is enable the DeriveFunctor language extension and add a deriving clause.

{-# LANGUAGE DeriveFunctor #-}
module Main where

data Tree a = Leaf | Node (Tree a) a (Tree a)
            deriving Functor

This would already allow us to increment all nodes in the tree by one for example:

incrementNodes :: Tree Int -> Tree Int
incrementNodes tree = fmap (+1) tree
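
For comparison, this is roughly what we would have to write by hand in Java to get the same mapping. It is only a sketch: the map method and the TreeMapSketch wrapper are hypothetical names, null stands in for Leaf, and unlike fmap it only works on int values instead of any element type.

```java
import java.util.function.IntUnaryOperator;

class TreeMapSketch {
    // Same shape as the Java Node class from earlier; null plays the role of Leaf.
    static final class Node {
        final Node left;
        final int value;
        final Node right;

        Node(Node left, int value, Node right) {
            this.left = left;
            this.value = value;
            this.right = right;
        }

        // Hypothetical hand-written counterpart of fmap:
        // keeps the shape of the tree and applies f to every value.
        Node map(IntUnaryOperator f) {
            Node newLeft = left == null ? null : left.map(f);
            Node newRight = right == null ? null : right.map(f);
            return new Node(newLeft, f.applyAsInt(value), newRight);
        }
    }

    static Node incrementNodes(Node tree) {
        return tree.map(v -> v + 1);
    }

    public static void main(String[] args) {
        Node tree = new Node(new Node(null, 1, null), 2, null);
        Node result = incrementNodes(tree);
        System.out.println(result.left.value + " " + result.value); // prints "2 3"
    }
}
```

The recursion in map is the part that deriving Functor writes for us in Haskell.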

But now we would like to map and collect at the same time. We have solved the first half, so let's do the second half. For collapsing a structure down to a single value, there is the type class Foldable. Again, arrays and lists are foldable, and folding works much like reduce or collect on Java 8 streams and Array.prototype.reduce in JavaScript. And again, the compiler can generate the instance automatically for you:

{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DeriveFoldable #-}
module Main where

data Tree a = Leaf | Node (Tree a) a (Tree a)
            deriving (Functor, Foldable)

Now we could calculate the sum of all nodes for example:

sumNodes :: Tree Int -> Int
sumNodes tree = foldr (+) 0 tree
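
Again for comparison, a hand-written fold over the Java tree could look like the sketch below. The foldr method and the TreeFoldSketch wrapper are hypothetical names chosen for this post, and the method is restricted to int values and int results, which the derived Haskell instance is not.

```java
import java.util.function.IntBinaryOperator;

class TreeFoldSketch {
    // Same shape as the Java Node class from earlier; null plays the role of Leaf.
    static final class Node {
        final Node left;
        final int value;
        final Node right;

        Node(Node left, int value, Node right) {
            this.left = left;
            this.value = value;
            this.right = right;
        }

        // Hypothetical hand-written counterpart of foldr for this tree:
        // fold the right subtree first, then combine with this value,
        // then continue with the left subtree.
        int foldr(IntBinaryOperator f, int init) {
            int fromRight = right == null ? init : right.foldr(f, init);
            int withSelf = f.applyAsInt(value, fromRight);
            return left == null ? withSelf : left.foldr(f, withSelf);
        }
    }

    static int sumNodes(Node tree) {
        return tree.foldr(Integer::sum, 0);
    }

    public static void main(String[] args) {
        // The example tree from the beginning of the post
        Node testTree = new Node(
            new Node(new Node(null, 1, null), 2, null),
            5,
            new Node(
                new Node(null, 6, new Node(null, 6, null)),
                7,
                new Node(null, 8, null)));
        System.out.println(sumNodes(testTree)); // prints 35
    }
}
```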

Now to the last step: mapping and folding in one. For this we need another abstraction: Traversable. This class lets you map and simultaneously keep track of some "side effects", but it needs your data to be both a Functor and Foldable already. For example, we can keep track of some local state, in our case the current sum. And like the classes before, the compiler can generate it automatically for us:

{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DeriveFoldable #-}
{-# LANGUAGE DeriveTraversable #-}
module Main where

data Tree a = Leaf | Node (Tree a) a (Tree a)
            deriving (Functor, Foldable, Traversable)

With all this in place, our final solution is pretty simple:

solve :: Tree Int -> (Int, Tree Int)
solve tree = mapAccumR (\a b -> (a + b, a)) 0 tree

mapAccumR is a function from the Haskell standard library (module Data.Traversable) that is defined essentially like this:

mapAccumR :: Traversable t => (a -> b -> (a, c)) -> a -> t b -> (a, t c)
mapAccumR fun init x = runStateR (traverse (StateR . flip fun) x) init

The exact definition is not that important. Just note that inside the inner parentheses it uses StateR to make your function behave like a side effect. It then uses traverse to combine the stateful effects in the right order, which for mapAccumR means from right to left, and runStateR executes them one by one, starting from the initial state.
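
If the Haskell definition looks opaque, its behaviour can be approximated in Java for plain lists. This is only a sketch under assumptions: Step and this Java mapAccumR are hypothetical helpers written for illustration, they work on a List instead of any Traversable, and they use a simple loop where the real definition uses StateR and traverse.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.function.BiFunction;

class MapAccumSketch {
    // Hypothetical helper: an accumulator together with an output value.
    static final class Step<A, C> {
        final A acc;
        final C out;

        Step(A acc, C out) {
            this.acc = acc;
            this.out = out;
        }
    }

    // Walks the list from right to left and threads the accumulator through fun.
    static <A, B, C> Step<A, List<C>> mapAccumR(
            BiFunction<A, B, Step<A, C>> fun, A init, List<B> xs) {
        A acc = init;
        List<C> outs = new ArrayList<>();
        for (int i = xs.size() - 1; i >= 0; i--) {
            Step<A, C> step = fun.apply(acc, xs.get(i));
            acc = step.acc;
            outs.add(step.out);
        }
        Collections.reverse(outs); // restore the original left-to-right order
        return new Step<>(acc, outs);
    }

    public static void main(String[] args) {
        // Values of the example tree, read from left to right
        List<Integer> values = Arrays.asList(1, 2, 5, 6, 6, 7, 8);
        // Same function as in solve: new state is a + b, output is the old state a
        BiFunction<Integer, Integer, Step<Integer, Integer>> f =
            (a, b) -> new Step<>(a + b, a);
        Step<Integer, List<Integer>> result = mapAccumR(f, 0, values);
        System.out.println(result.acc + " " + result.out);
        // prints 35 [34, 32, 27, 21, 15, 8, 0]
    }
}
```

The output list contains the values of the result tree read from left to right, and the final accumulator is the sum of all nodes.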

Java required us to write a toString method by hand. In Haskell we can again let the compiler do this by deriving Show. So with a main function to make everything runnable, our complete code for the puzzle is just:

{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DeriveFoldable #-}
{-# LANGUAGE DeriveTraversable #-}
module Main where

-- mapAccumR is not exported by the Prelude
import Data.Traversable (mapAccumR)

data Tree a = Leaf | Node (Tree a) a (Tree a)
            deriving (Show, Functor, Foldable, Traversable)

solve :: Tree Int -> (Int, Tree Int)
solve tree = mapAccumR (\a b -> (a + b, a)) 0 tree

testTree :: Tree Int
testTree =
    Node
        (Node
            (Node Leaf 1 Leaf)
            2
            Leaf
        )
        5
        (Node
            (Node
                Leaf
                6
                (Node Leaf 6 Leaf)
            )
            7
            (Node Leaf 8 Leaf)
        )

main :: IO ()
main = do
    -- The first component is the sum of all nodes; we only print the tree
    let (_, tree) = solve testTree
    print tree

Conclusion

As you can see, if the compiler has your back, you can drastically cut down the number of lines of code. More code tends to mean more bugs, so everything we do not have to write ourselves is good.
