Tags: lambda, f#, continuations

Processing a tree in F# using continuations


I'm trying to understand how continuations work. I came across this example in the book Real World Functional Programming by Tomas Petricek with Jon Skeet, but it has really got my head spinning, so I'd like to ask for some detailed help.

type IntTree = 
    | Leaf of int
    | Node of IntTree * IntTree

let rec sumTreeCont tree cont =
  match tree with
  | Leaf(n) -> cont(n)
  | Node(left, right) -> 
      sumTreeCont left (fun leftSum ->
        sumTreeCont right (fun rightSum ->
          cont(leftSum + rightSum)))
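The function is started by passing it an initial continuation. As a minimal usage sketch, assuming a small hand-built tree called sample (a hypothetical name, not from the book), passing id returns the plain sum, while any other function receives the sum and decides what to return:

```fsharp
type IntTree =
    | Leaf of int
    | Node of IntTree * IntTree

let rec sumTreeCont tree cont =
    match tree with
    | Leaf(n) -> cont(n)
    | Node(left, right) ->
        sumTreeCont left (fun leftSum ->
            sumTreeCont right (fun rightSum ->
                cont(leftSum + rightSum)))

// Hypothetical sample tree: ((1, 2), 3)
let sample = Node(Node(Leaf 1, Leaf 2), Leaf 3)

// id as the initial continuation simply hands back the final sum.
let total = sumTreeCont sample id

// Any other continuation receives the sum instead, here to format it.
let message = sumTreeCont sample (fun s -> sprintf "Sum: %d" s)

printfn "%d" total
printfn "%s" message
```

Because the continuation decides the final result, the same function returns an int with id and a string with the formatting lambda.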

Okay, here's what I've been able to figure out myself. In the second branch we first process the left side of the node and pass a lambda. This lambda will instantiate a closure class with two fields, right: IntTree and cont: (int -> 'a), and that closure will be invoked by the base case. But then the "inner lambda" also seems to capture leftSum, and I don't quite understand how it all fits together. I have to admit that I get a bit dizzy when I try to figure this out.
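One way to see what each lambda captures is to give the two continuations names. The following sketch should behave the same as sumTreeCont; the names sumTreeNamed, afterLeft and afterRight are hypothetical, introduced only for illustration. afterLeft closes over right and cont. afterRight is only created when afterLeft runs, that is, once the left sum is known, so it closes over leftSum and cont.

```fsharp
type IntTree =
    | Leaf of int
    | Node of IntTree * IntTree

let rec sumTreeNamed (tree: IntTree) (cont: int -> 'a) : 'a =
    match tree with
    | Leaf n -> cont n
    | Node(left, right) ->
        // Called once the left subtree has been summed.
        // Closes over: right and cont (both known when this closure is created).
        let afterLeft (leftSum: int) =
            // Called once the right subtree has been summed.
            // Closes over: leftSum (the argument afterLeft just received) and cont.
            let afterRight (rightSum: int) =
                cont (leftSum + rightSum)
            sumTreeNamed right afterRight
        sumTreeNamed left afterLeft

// Hypothetical sample tree: ((1, 2), 3)
let sample = Node(Node(Leaf 1, Leaf 2), Leaf 3)
printfn "%d" (sumTreeNamed sample id)
```

So no single closure holds everything: each level of nesting adds one more captured value, and leftSum only exists once the whole left subtree has been summed.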


Solution

  • I think Christian's answer is a good one: continuation passing style is really just a (not so) simple mechanical transformation that you apply to the original source code. This may be easier to see if you do it step by step:

    1) Start with the original code (here, I changed the code to do only one operation per line):

    let rec sumTree tree =
       match tree with
       | Leaf(n) -> n
       | Node(left, right) -> 
           let leftSum = sumTree left
           let rightSum = sumTree right
           leftSum + rightSum
    

    2) Add a continuation parameter and call it instead of returning the result (this is still not tail-recursive). To make this type-check, I added the continuation fun x -> x to both sub-calls so that they just return the sum as the result:

    let rec sumTree tree cont =
       match tree with
       | Leaf(n) -> cont n
       | Node(left, right) -> 
           let leftSum = sumTree left (fun x -> x)
           let rightSum = sumTree right (fun x -> x)
           cont (leftSum + rightSum)
    

    3) Now, change the first recursive call to use continuation passing style by lifting the rest of the body into the continuation:

    let rec sumTree tree cont =
       match tree with
       | Leaf(n) -> cont n
       | Node(left, right) -> 
           sumTree left (fun leftSum ->
             let rightSum = sumTree right (fun x -> x)
             cont (leftSum + rightSum) )
    

    4) Repeat the same transformation for the second recursive call:

    let rec sumTree tree cont =
       match tree with
       | Leaf(n) -> cont n
       | Node(left, right) -> 
           sumTree left (fun leftSum ->
             sumTree right (fun rightSum -> 
               cont (leftSum + rightSum) ))
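To watch the pieces fit together at run time, here is an instrumented copy of the step 4 function. The events list and the message texts are hypothetical additions for illustration only; the control flow is unchanged. Running it on Node(Leaf 1, Leaf 2) should show that the left leaf calls the first continuation, which then starts on the right subtree with leftSum already fixed, and only the innermost continuation adds the two results.

```fsharp
type IntTree =
    | Leaf of int
    | Node of IntTree * IntTree

// Hypothetical trace buffer, added only to record the order of events.
let events = System.Collections.Generic.List<string>()

// Same control flow as step 4, plus one trace entry per step.
let rec sumTree tree cont =
    match tree with
    | Leaf(n) ->
        events.Add(sprintf "leaf %d: calling continuation" n)
        cont n
    | Node(left, right) ->
        events.Add("node: descending into left subtree")
        sumTree left (fun leftSum ->
            events.Add(sprintf "left done (leftSum = %d): descending into right subtree" leftSum)
            sumTree right (fun rightSum ->
                events.Add(sprintf "right done (rightSum = %d): adding" rightSum)
                cont (leftSum + rightSum)))

let result = sumTree (Node(Leaf 1, Leaf 2)) id
events |> Seq.iter (printfn "%s")
printfn "result = %d" result
```

Every recursive call and every call to cont now sits in tail position, so the pending work lives in the chain of closures on the heap rather than on the call stack (provided the compiler emits .NET tail calls, which F# build settings control).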