I'm trying to understand how continuations work. I came across this example in the book Real World Functional Programming by Tomas Petricek with Jon Skeet, but it really has my head spinning, so I must ask for some detailed help:
type IntTree =
  | Leaf of int
  | Node of IntTree * IntTree

let rec sumTreeCont tree cont =
  match tree with
  | Leaf(n) -> cont(n)
  | Node(left, right) ->
      sumTreeCont left (fun leftSum ->
        sumTreeCont right (fun rightSum ->
          cont(leftSum + rightSum)))
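To see the function in action, here is a small sketch that calls sumTreeCont on a hypothetical sample tree (the name sampleTree and the comments are mine, not from the book): once with the identity continuation id, and once with a continuation that formats the result.

```fsharp
type IntTree =
  | Leaf of int
  | Node of IntTree * IntTree

let rec sumTreeCont tree cont =
  match tree with
  | Leaf(n) -> cont(n)
  | Node(left, right) ->
      // Sum the left subtree first; the lambda receives that sum later
      sumTreeCont left (fun leftSum ->
        // Then sum the right subtree, with leftSum remembered in the closure
        sumTreeCont right (fun rightSum ->
          // Finally hand the combined sum to the caller's continuation
          cont(leftSum + rightSum)))

// Hypothetical sample tree: 1 + (2 + 3)
let sampleTree = Node(Leaf 1, Node(Leaf 2, Leaf 3))

// The identity continuation simply returns the sum
let total = sumTreeCont sampleTree id
// Any other continuation decides what to do with the final sum
let message = sumTreeCont sampleTree (fun s -> sprintf "sum = %d" s)
printfn "%d / %s" total message
```

The caller never gets the sum back as an ordinary return value from the inner calls; instead, whoever starts the computation decides (through cont) what happens to the final result.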
Okay, here's what I have been able to figure out myself. In the second branch, we first process the left side of the node and pass a lambda. This lambda will instantiate a closure class with two fields, right: IntTree and cont: (int -> 'a), which will be invoked by the base case. But it also seems that the "inner lambda" captures leftSum. I don't quite understand how it all fits together, though; I have to admit that I get a bit dizzy when I try to figure this out.
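One way to make the captured fields visible is to replace the lambdas with an explicit data type, a technique usually called defunctionalization. This is my own illustration, not code from the book: the hypothetical Cont type has one case per lambda, and each case stores exactly the variables that lambda closes over (right and cont for the outer lambda, leftSum and cont for the inner one). For simplicity it assumes the final result is an int.

```fsharp
type IntTree =
  | Leaf of int
  | Node of IntTree * IntTree

// Each case stands for one kind of lambda in sumTreeCont,
// and its fields are the values that lambda captures.
type Cont =
  | Done                          // the initial continuation (like id)
  | AfterLeft of IntTree * Cont   // fun leftSum -> ...  captures right, cont
  | AfterRight of int * Cont      // fun rightSum -> ... captures leftSum, cont

// Walk the tree, building continuation values instead of lambdas
let rec sumTreeDefun tree (cont: Cont) : int =
  match tree with
  | Leaf n -> apply cont n
  | Node(left, right) -> sumTreeDefun left (AfterLeft(right, cont))

// "Invoke" a continuation: this is what calling the lambda used to do
and apply (cont: Cont) (value: int) : int =
  match cont with
  | Done -> value
  | AfterLeft(right, next) ->
      // value is leftSum; now process the right subtree
      sumTreeDefun right (AfterRight(value, next))
  | AfterRight(leftSum, next) ->
      // value is rightSum; combine and pass it to the outer continuation
      apply next (leftSum + value)

let sampleTree = Node(Leaf 1, Node(Leaf 2, Leaf 3))
let result = sumTreeDefun sampleTree Done
printfn "%d" result
```

Each AfterRight value plays the role of the inner closure: it holds the already computed leftSum plus the continuation that was in effect when the node was first visited.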
I think Christian's answer is a good one: continuation-passing style is really just a (not so) simple mechanical transformation that you apply to the original source code. This may be easier to see when you do it step by step:
1) Start with the original code (here, I changed the code to do only one operation per line):
let rec sumTree tree =
  match tree with
  | Leaf(n) -> n
  | Node(left, right) ->
      let leftSum = sumTree left
      let rightSum = sumTree right
      leftSum + rightSum
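For reference, step 1 runs as-is. This sketch adds a hypothetical sample tree; note that neither recursive call is in tail position (work remains after each one), so the stack grows with the depth of the tree.

```fsharp
type IntTree =
  | Leaf of int
  | Node of IntTree * IntTree

// Step 1: direct style, one operation per line
let rec sumTree tree =
  match tree with
  | Leaf(n) -> n
  | Node(left, right) ->
      let leftSum = sumTree left     // not a tail call: work remains after it
      let rightSum = sumTree right   // not a tail call either
      leftSum + rightSum

let sampleTree = Node(Leaf 1, Node(Leaf 2, Leaf 3))
printfn "%d" (sumTree sampleTree)
```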
2) Add a continuation parameter and call it instead of returning the result (this is still not tail-recursive). To make this type-check, I added the continuation fun x -> x to both sub-calls so that they just return the sum as the result:
let rec sumTree tree cont =
  match tree with
  | Leaf(n) -> cont n
  | Node(left, right) ->
      let leftSum = sumTree left (fun x -> x)
      let rightSum = sumTree right (fun x -> x)
      cont (leftSum + rightSum)
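A sketch of step 2 in use. One detail worth knowing (this comes from F#'s type inference rules, not from the steps above): inside a let rec, recursive calls are not generalized, so passing fun x -> x to the sub-calls fixes the continuation type to int -> int. The hypothetical usage below therefore uses a continuation that returns an int.

```fsharp
type IntTree =
  | Leaf of int
  | Node of IntTree * IntTree

// Step 2: continuation parameter added, but the sub-calls still return directly
let rec sumTree tree cont =
  match tree with
  | Leaf(n) -> cont n
  | Node(left, right) ->
      // fun x -> x makes each sub-call just return its sum
      let leftSum = sumTree left (fun x -> x)
      let rightSum = sumTree right (fun x -> x)
      cont (leftSum + rightSum)

let sampleTree = Node(Leaf 1, Node(Leaf 2, Leaf 3))
// cont is inferred as int -> int here, so it must return an int
let scaled = sumTree sampleTree (fun total -> total * 10)
printfn "%d" scaled
```

Only the top-level call sees the caller's continuation; every nested call receives fun x -> x.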
3) Now, let's change the first recursive call to use continuation-passing style by lifting the rest of the body into the continuation:
let rec sumTree tree cont =
  match tree with
  | Leaf(n) -> cont n
  | Node(left, right) ->
      sumTree left (fun leftSum ->
        let rightSum = sumTree right (fun x -> x)
        cont (leftSum + rightSum) )
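Step 3 can be checked the same way. In this hypothetical run, the first recursive call is now in tail position and the rest of the work lives in its continuation, while the second call still passes fun x -> x (so the continuation type is again fixed to int -> int).

```fsharp
type IntTree =
  | Leaf of int
  | Node of IntTree * IntTree

// Step 3: the left call is in CPS; the right call still returns directly
let rec sumTree tree cont =
  match tree with
  | Leaf(n) -> cont n
  | Node(left, right) ->
      sumTree left (fun leftSum ->
        // runs once the left sum is known
        let rightSum = sumTree right (fun x -> x)
        cont (leftSum + rightSum) )

let sampleTree = Node(Leaf 1, Node(Leaf 2, Leaf 3))
let scaled = sumTree sampleTree (fun total -> total * 10)
printfn "%d" scaled
```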
4) And repeat the same thing for the second recursive call:
let rec sumTree tree cont =
  match tree with
  | Leaf(n) -> cont n
  | Node(left, right) ->
      sumTree left (fun leftSum ->
        sumTree right (fun rightSum ->
          cont (leftSum + rightSum) ))
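The result of step 4 is the book's sumTreeCont. Because no sub-call passes fun x -> x any more, F# can generalize it to IntTree -> (int -> 'a) -> 'a, so the same function works with continuations of different result types. The sample tree and the two continuations below are hypothetical usage, not part of the original answer.

```fsharp
type IntTree =
  | Leaf of int
  | Node of IntTree * IntTree

// Step 4: both recursive calls in CPS; every call here is a tail call
let rec sumTree tree cont =
  match tree with
  | Leaf(n) -> cont n
  | Node(left, right) ->
      sumTree left (fun leftSum ->
        sumTree right (fun rightSum ->
          cont (leftSum + rightSum) ))

let sampleTree = Node(Leaf 1, Node(Leaf 2, Leaf 3))

// The same function, used with continuations of different result types
let asInt : int = sumTree sampleTree id
let asText : string = sumTree sampleTree (fun total -> sprintf "total: %d" total)
printfn "%d %s" asInt asText
```

Since every call is now a tail call, the pending work is kept in heap-allocated closures rather than on the stack, provided the compiler actually emits .NET tail calls (controlled by the --tailcalls option).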