Understanding foldTree function's type derivation


Consider this definition from Data.Tree:

foldTree :: (a -> [b] -> b) -> Tree a -> b
foldTree f = go where
    go (Node x ts) = f x (map go ts)

My specific question is: since the name go appears on the right-hand side of its own equation, in (map go ts), how is the type of the function

(a -> [b] -> b)

inferred?

For example, given this line of code:

foldTree (:) (Node 1 [Node 2 []])

and instantiating the definition:

foldTree (:) = go where
    go (Node 1 [Node 2 []]) = (:) 1 (map go [Node 2 []])

(:) 1 (map go [Node 2 []]) is not fully evaluated, so I only see that (:) 1 has the type Num a => [a] -> [a]. One piece is still missing, and filling it in seems to require the recursion to complete. So there seems to be some circularity in calculating the type.

Any insights are much appreciated.


Solution

  • Haskell's type inference is very clever! I can't tell you exactly how GHC infers this, but let's walk through how it might go. The reality is probably not too far off. Note that the type signature is not actually required in this case; the compiler can work it out from the definition alone.

    foldTree f = go where
        go (Node x ts) = f x (map go ts)
    

    foldTree is defined to take an argument, and go is defined to take an argument, so we know right from the start that these are functions.

    foldTree :: _a -> _b
    foldTree f = go where
        go :: _c -> _d
        go (Node x ts) = f x (map go ts)
    

    Now we see that f is called with two arguments, so it must actually be a function of (at least) two arguments.

    foldTree :: (_x -> _y -> _z) -> _b
    foldTree f = go where
        go :: _c -> _d
        go (Node x ts) = f x (map go ts)
    

    Since foldTree f = go, and go :: _c -> _d, the result type _b must actually be _c -> _d [*]:

    foldTree :: (_x -> _y -> _z) -> _c -> _d
    foldTree f = go where
        go :: _c -> _d
        go (Node x ts) = f x (map go ts)
    

    The second argument passed to f (of type _y) is map go ts. Since go :: _c -> _d, mapping go over a list produces a list of _d, so _y must be [_d]:

    foldTree :: (_x -> [_d] -> _z) -> _c -> _d
    foldTree f = go where
        go :: _c -> _d
        go (Node x ts) = f x (map go ts)
    

    go matches its argument against Node x ts, and Node is a data constructor for Tree, so go's argument (_c) must be a Tree.

    foldTree :: (_x -> [_d] -> _z) -> Tree _p -> _d
    foldTree f = go where
        go :: Tree _p -> _d
        go (Node x ts) = f x (map go ts)
    

    The first field of the Node constructor is passed as the first argument of f, so _x and _p must be the same:

    foldTree :: (_x -> [_d] -> _z) -> Tree _x -> _d
    foldTree f = go where
        go :: Tree _x -> _d
        go (Node x ts) = f x (map go ts)
    

    Since go _ is defined as f _ _, they must have results of the same type, so _z is _d:

    foldTree :: (_x -> [_d] -> _d) -> Tree _x -> _d
    foldTree f = go where
        go :: Tree _x -> _d
        go (Node x ts) = f x (map go ts)
    

    Whew. Now the compiler checks that these types work out (which they do) and "generalizes" them, turning the "metavariables" (placeholders for types the inference engine has not determined yet) into quantified type variables (variables that are definitely polymorphic). The result is

    foldTree :: forall a b. (a -> [b] -> b) -> Tree a -> b
    foldTree f = go where
        go :: Tree a -> b
        go (Node x ts) = f x (map go ts)
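
The steps above can be replayed mechanically. The following is a minimal sketch of first-order unification, not how GHC is actually implemented: the Ty type, the helper names (fn, list, tree, mv, unify) and the constraint list are all assumptions made for this illustration. The point is that go gets a placeholder type before its body is examined, so the recursive use in map go ts just contributes one more equation and nothing circular has to be evaluated.

```haskell
import Control.Monad (foldM)
import qualified Data.Map as M

-- Hypothetical miniature type language, just large enough for this example.
data Ty = TVar String | TCon String [Ty]
  deriving (Eq, Show)

fn :: Ty -> Ty -> Ty
fn a b = TCon "->" [a, b]

list, tree, mv :: Ty -> Ty
list a = TCon "[]" [a]
tree a = TCon "Tree" [a]
mv = id  -- marks a metavariable at use sites; purely cosmetic

-- A substitution maps metavariable names to the types learned so far.
type Subst = M.Map String Ty

-- Replace every solved metavariable in a type, following chains.
apply :: Subst -> Ty -> Ty
apply s (TVar v)    = maybe (TVar v) (apply s) (M.lookup v s)
apply s (TCon c ts) = TCon c (map (apply s) ts)

occurs :: String -> Ty -> Bool
occurs v (TVar w)    = v == w
occurs v (TCon _ ts) = any (occurs v) ts

-- Extend the substitution so that both sides of one equation agree.
unify :: Subst -> (Ty, Ty) -> Maybe Subst
unify s (l, r) = go (apply s l) (apply s r)
  where
    go (TVar v) t = bind v t
    go t (TVar v) = bind v t
    go (TCon c xs) (TCon d ys)
      | c == d && length xs == length ys = foldM unify s (zip xs ys)
      | otherwise                        = Nothing
    bind v t
      | t == TVar v = Just s
      | occurs v t  = Nothing  -- would require an infinite type
      | otherwise   = Just (M.insert v t s)

-- The equations collected in the walkthrough, in the same order.
constraints :: [(Ty, Ty)]
constraints =
  [ (mv (TVar "a"), fn (TVar "x") (fn (TVar "y") (TVar "z")))  -- f takes two arguments
  , (mv (TVar "b"), fn (TVar "c") (TVar "d"))                  -- foldTree f = go
  , (mv (TVar "y"), list (TVar "d"))                           -- map go ts :: [_d]
  , (mv (TVar "c"), tree (TVar "p"))                           -- go matches on Node
  , (mv (TVar "p"), TVar "x")                                  -- x is f's first argument
  , (mv (TVar "z"), TVar "d")                                  -- go's result is f's result
  ]

-- foldTree :: _a -> _b, with everything learned substituted in.
inferred :: Maybe Ty
inferred = do
  s <- foldM unify M.empty constraints
  pure (apply s (fn (TVar "a") (TVar "b")))

-- (_x -> [_d] -> _d) -> Tree _x -> _d
expected :: Ty
expected =
  fn (fn (TVar "x") (fn (list (TVar "d")) (TVar "d")))
     (fn (tree (TVar "x")) (TVar "d"))

main :: IO ()
main = print (inferred == Just expected)  -- prints True
```

Generalization, the last step of the walkthrough, is not modelled here; the sketch stops at the solved metavariables.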
    

    The reality is a bit more complicated, but this should give you the gist.
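
One way to check the result against the compiler itself is to drop the signature and let GHC infer it. The sketch below copies the definition under the hypothetical name foldTree' (to avoid clashing with Data.Tree.foldTree) and uses it at two different result types; the folding functions and sample tree are illustrative assumptions.

```haskell
import Data.Tree (Tree (..))

-- Same body as Data.Tree.foldTree, renamed and with no type signature,
-- so the compiler has to infer one.
foldTree' f = go
  where
    go (Node x ts) = f x (map go ts)

-- Sum every label: here both a and b are Int.
total :: Int
total = foldTree' (\x rs -> x + sum rs) (Node 1 [Node 2 [], Node 3 []])

-- Collect the labels in pre-order: here a is Int and b is [Int].
labels :: [Int]
labels = foldTree' (\x rs -> x : concat rs) (Node 1 [Node 2 [], Node 3 []])

main :: IO ()
main = do
  print total   -- 6
  print labels  -- [1,2,3]
```

Loading this in GHCi and asking :type foldTree' should report a type of the shape (a -> [b] -> b) -> Tree a -> b, possibly with different variable names. The example deliberately avoids (:) as the folding function: (:) has type a -> [a] -> [a], so matching it against a -> [b] -> b would need b and [b] to be the same type, which the type checker rejects with an occurs-check error.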

    [*] This step is a bit of a cheat. I'm ignoring a feature called "let generalization", which isn't needed in this context and which is actually disabled by several language extensions in GHC Haskell.
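
For context on the footnote, here is a small sketch of what let generalization buys when it is enabled. The names pairUp and tag are hypothetical, and the NoMonoLocalBinds pragma is included only to make the assumption explicit.

```haskell
{-# LANGUAGE NoMonoLocalBinds #-}

-- With let generalization, the local helper gets a polymorphic type,
-- so it can be used at two different argument types in the same body.
pairUp :: a -> ((a, Bool), (a, Char))
pairUp x = (tag True, tag 'c')
  where
    -- Generalized over the type of y, even though it mentions x.
    tag y = (x, y)

main :: IO ()
main = print (pairUp (0 :: Int))  -- ((0,True),(0,'c'))
```

With MonoLocalBinds turned on (which, per the GHC documentation, extensions such as GADTs and TypeFamilies imply), tag mentions the lambda-bound x and so should not be generalized, and the two uses at Bool and Char would no longer type-check. In foldTree, go is only ever used at one type, which is why the walkthrough can ignore the feature.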