Tags: functional-programming, f#, continuation-passing

Simplify multiway tree traversal with continuation passing style


I am fascinated by the approach used in this blog post to traverse a rose tree (a.k.a. multiway tree, a.k.a. n-ary tree) using CPS.

Here is my code, written while I was trying to understand the technique, with the type annotations removed and the names changed:

type 'a Tree = Node of 'a * 'a Tree list | Leaf of 'a

let rec reduce recCalls cont =
    match recCalls with
    | [] -> [] |> cont 
    | findMaxCall :: pendingCalls ->
        findMaxCall (fun maxAtNode ->
                    reduce pendingCalls (fun maxVals -> maxAtNode :: maxVals |> cont))
        
let findMaxOf (roseTree : int Tree) =
    let rec findMax tr cont =
        match tr with
        | Leaf i -> i |> cont
        | Node (i, chld) ->
            let recCalls = chld |> List.map findMax 
            reduce recCalls (fun maxVals -> List.max (i :: maxVals) |> cont)
    findMax roseTree id 
    
// test it
let FindMaxOfRoseTree =
    let t = Node (1, [ Leaf 2; Leaf 3 ])
    let maxOf = findMaxOf t //will be 3
    maxOf

My problem is that I find this approach hard to follow. The mutual recursion (assuming that's the right term) seems really clever to my simpleton brain, but I get lost trying to understand how it works, even when I use simple examples and write down the steps by hand.

I need to use CPS with rose trees, and I'll be doing the kind of traversal that requires CPS because, just like in this example, computing results for my tree nodes requires the children of each node to be computed first. In any case, I do like CPS and I'd like to improve my understanding of it.

So my question is: is there an alternative way of implementing CPS on rose trees that I might find easier to follow? Is there a way to refactor the code above to make it easier to follow (eliminating the mutual recursion)?

If there is a name for the above approach, or some resources/books I can read to understand it better, hints are also most welcome.


Solution

    CPS can definitely be confusing, but there are some things you can do to simplify this code:

    • Remove the Leaf case from your type because it's redundant. A leaf is just a Node with an empty list of children.
    • Separate general-purpose CPS logic from logic that's specific to rose trees.
    • Use the continuation monad to simplify CPS code.

    First, let's define the continuation monad:

    type ContinuationMonad() =
        member __.Bind(m, f) = fun c -> m (fun a -> f a c)
        member __.Return(x) = fun k -> k x
    
    let cont = ContinuationMonad()
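To see what the builder actually does, it can help to expand a `let!` by hand. The minimal sketch below repeats the builder so it runs on its own; `twenty`, `addOne`, `addOneByHand` and `sumOfTwo` are hypothetical names used only for illustration. `addOne` and `addOneByHand` should behave identically, because the second is just the first with `Bind` and `Return` inlined.

```fsharp
// The builder from above, repeated so this snippet is self-contained.
type ContinuationMonad() =
    member __.Bind(m, f) = fun c -> m (fun a -> f a c)
    member __.Return(x) = fun k -> k x

let cont = ContinuationMonad()

// A hypothetical "incomplete computation": it needs a continuation k and calls it with 20.
let twenty k = k 20

// Builder version: take the value produced by m and add one to it.
let addOne m =
    cont {
        let! x = m
        return x + 1
    }

// The same function with the builder expanded by hand:
//   cont.Bind(m, fun x -> cont.Return(x + 1))
//   = fun c -> m (fun x -> (fun k -> k (x + 1)) c)
//   = fun c -> m (fun x -> c (x + 1))
let addOneByHand m = fun c -> m (fun x -> c (x + 1))

// Two let!s in a row: the second computation only runs inside the
// continuation of the first one.
let sumOfTwo a b =
    cont {
        let! x = a
        let! y = b
        return x + y
    }

// Passing id as the final continuation just returns the result.
printfn "%d" (addOne twenty id)           // 21
printfn "%d" (addOneByHand twenty id)     // 21
printfn "%d" (sumOfTwo twenty twenty id)  // 40
```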
    

    Using this builder, we can define a general-purpose CPS reduce function that combines a list of "incomplete" computations into a single incomplete computation that produces the list of their results. (An incomplete computation is any function that takes a continuation of type 't -> 'u and uses it to produce a value of type 'u.)

    let rec reduce fs =
        cont {
            match fs with
            | [] -> return []
            | head :: tail ->
                let! result = head
                let! results = reduce tail
                return result :: results
        }
    

    I think this is certainly clearer, but it might seem like magic. The key to understanding let! x = f for this builder is that x is the value passed to f's implied continuation. This allows us to get rid of lots of lambdas and nested parens.
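One way to take the magic out of it is to expand every `let!` in `reduce` by hand. The result, sketched below as `reduceByHand` (a hypothetical name), turns out to be essentially the `reduce` from the question, with `head`/`tail` in place of `findMaxCall`/`pendingCalls`. Both versions are included so they can be compared on the same input; `one` and `two` are hypothetical test computations.

```fsharp
type ContinuationMonad() =
    member __.Bind(m, f) = fun c -> m (fun a -> f a c)
    member __.Return(x) = fun k -> k x

let cont = ContinuationMonad()

// reduce from the answer, using the builder.
let rec reduce fs =
    cont {
        match fs with
        | [] -> return []
        | head :: tail ->
            let! result = head
            let! results = reduce tail
            return result :: results
    }

// The same function with each let! expanded into an explicit continuation:
// run head, then reduce the tail, then hand (result :: results) to k.
let rec reduceByHand fs k =
    match fs with
    | [] -> k []
    | head :: tail ->
        head (fun result ->
            reduceByHand tail (fun results ->
                k (result :: results)))

// Two trivial incomplete computations.
let one k = k 1
let two k = k 2

printfn "%A" (reduce [ one; two ] id)        // [1; 2]
printfn "%A" (reduceByHand [ one; two ] id)  // [1; 2]
```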

    Now we're ready to work with rose trees. Here's the simplified type definition:

    type 'a Tree = Node of 'a * 'a Tree list
    
    let leaf a = Node (a, [])
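If existing code already builds trees with the question's two-case type, a small conversion function shows why the `Leaf` case is redundant: every `Leaf a` maps to `Node (a, [])`, i.e. `leaf a`. In this sketch the question's type is renamed to `OldTree` (with cases `OldNode`/`OldLeaf`) so both types can live in one file; those names and `ofOldTree` are hypothetical.

```fsharp
// The question's two-case tree, renamed so it can sit next to the new type.
type 'a OldTree = OldNode of 'a * 'a OldTree list | OldLeaf of 'a

// The simplified single-case tree from the answer.
type 'a Tree = Node of 'a * 'a Tree list

let leaf a = Node (a, [])

// A leaf becomes a node with no children; a node keeps its value
// and converts each child recursively.
let rec ofOldTree tree =
    match tree with
    | OldLeaf a -> leaf a
    | OldNode (a, children) -> Node (a, List.map ofOldTree children)

let converted = ofOldTree (OldNode (1, [ OldLeaf 2; OldLeaf 3 ]))
printfn "%A" converted
```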
    

    Finding the maximum value in a tree now looks like this:

    let rec findMax (Node (i, chld)) =
        cont {
            let! maxVals = chld |> List.map findMax |> reduce
            return List.max (i :: maxVals)
        }
    

    Note that there's no mutual recursion here. Both reduce and findMax are self-recursive, but reduce doesn't call findMax and doesn't know anything about rose trees.
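Because `reduce` knows nothing about rose trees, the same pattern can be reused for other bottom-up traversals by changing only the line after `let!`. The minimal sketch below repeats the answer's definitions and adds two hypothetical traversals, `sumTree` and `depth`, plus a direct call to `reduce` on values that have nothing to do with trees.

```fsharp
type ContinuationMonad() =
    member __.Bind(m, f) = fun c -> m (fun a -> f a c)
    member __.Return(x) = fun k -> k x

let cont = ContinuationMonad()

let rec reduce fs =
    cont {
        match fs with
        | [] -> return []
        | head :: tail ->
            let! result = head
            let! results = reduce tail
            return result :: results
    }

type 'a Tree = Node of 'a * 'a Tree list

let leaf a = Node (a, [])

// Sum of every value in the tree: children first, then add the node's own value.
let rec sumTree (Node (i, chld) : int Tree) =
    cont {
        let! childSums = chld |> List.map sumTree |> reduce
        return i + List.sum childSums
    }

// Depth of the tree, counting a leaf as depth 1.
let rec depth (Node (_, chld)) =
    cont {
        let! childDepths = chld |> List.map depth |> reduce
        return 1 + (if List.isEmpty childDepths then 0 else List.max childDepths)
    }

let t = Node (1, [ Node (2, [ leaf 4 ]); leaf 3 ])
printfn "%d" (sumTree t id)  // 10
printfn "%d" (depth t id)    // 3

// reduce on plain CPS values, no tree involved.
printfn "%A" (reduce [ (fun k -> k "a"); (fun k -> k "b") ] id)  // ["a"; "b"]
```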

    You can test the refactored code like this:

    let t = Node (1, [ leaf 2; leaf 3 ])
    findMax t (printfn "%A")   // will be 3
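Printing is just one possible final continuation. Passing `id` instead returns the maximum as an ordinary value, which is handy for testing. The sketch below repeats the answer's definitions and adds, as comments, a hand-written trace of how `findMax` evaluates on the example tree (the kind of step-by-step expansion the question mentions); `deeper` is a hypothetical second test tree.

```fsharp
type ContinuationMonad() =
    member __.Bind(m, f) = fun c -> m (fun a -> f a c)
    member __.Return(x) = fun k -> k x

let cont = ContinuationMonad()

let rec reduce fs =
    cont {
        match fs with
        | [] -> return []
        | head :: tail ->
            let! result = head
            let! results = reduce tail
            return result :: results
    }

type 'a Tree = Node of 'a * 'a Tree list

let leaf a = Node (a, [])

let rec findMax (Node (i, chld)) =
    cont {
        let! maxVals = chld |> List.map findMax |> reduce
        return List.max (i :: maxVals)
    }

// Trace for t = Node (1, [ leaf 2; leaf 3 ]) with final continuation k:
//   findMax (leaf 2) = fun c -> reduce [] (fun vs -> c (List.max (2 :: vs)))
//                    = fun c -> c 2   (likewise findMax (leaf 3) = fun c -> c 3)
//   findMax t k = reduce [f2; f3] (fun vs -> k (List.max (1 :: vs)))
//     f2 passes 2 to its continuation, which reduces [f3];
//     f3 passes 3 to its continuation, which reduces [];
//     reduce [] passes [], and the pending continuations cons 3, then 2:
//               = k (List.max [1; 2; 3])
//               = k 3
let t = Node (1, [ leaf 2; leaf 3 ])
printfn "%d" (findMax t id)       // 3

let deeper = Node (5, [ Node (2, [ leaf 9; leaf 1 ]); leaf 7 ])
printfn "%d" (findMax deeper id)  // 9
```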
    

    For convenience, I created a gist containing all the code.