
Does the continuation + tail recursion trick actually trade stack space for heap space?


There is this CPS trick in functional programming: take a non-tail-recursive function and rewrite it in continuation-passing style (CPS), which trivially makes it tail-recursive. A lot of questions already cover this.

Take this example:

let rec count n = 
    if n = 0
      then 0
      else 1 + count (n - 1)

let rec countCPS n cont =
    if n = 0
      then cont 0
      else countCPS (n - 1) (fun ret -> cont (ret + 1))

The first version of count will accumulate stack frames in each recursive call, producing a stack overflow at around n = 60000 on my computer.

The idea of the CPS trick is that the countCPS implementation is tail-recursive, so that the computation

let f = countCPS 60000

will actually be optimized to run as a loop and work without problems. Instead of stack frames, the continuation to be run accumulates at every step, but this is an honest object on the heap, where memory is not a problem. So CPS is said to trade stack space for heap space. But I'm skeptical that it even does that.

Here's why: Evaluating the computation by actually running the continuation as countCPS 60000 (fun x -> x) blows my stack! Each call

countCPS (n - 1) (fun ret -> cont (ret + 1))

generates a new continuation closure from the old one, and running it involves one function application. So when evaluating countCPS 60000 (fun x -> x), we invoke a nested sequence of 60000 closures, and even though their data lies on the heap, we still have function applications, so the stack frames are back.
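Unfolding a small case by hand makes the nesting visible. This sketch writes out the continuations that `countCPS 3 (fun x -> x)` creates, one per loop iteration; the names `k0` to `k3` are illustrative only and `id` plays the role of `fun x -> x`.

```fsharp
let rec countCPS n cont =
    if n = 0 then cont 0
    else countCPS (n - 1) (fun ret -> cont (ret + 1))

// The continuations that countCPS 3 id builds, innermost first.
let k0 : int -> int = id
let k1 = fun ret -> k0 (ret + 1)   // built when n = 3
let k2 = fun ret -> k1 (ret + 1)   // built when n = 2
let k3 = fun ret -> k2 (ret + 1)   // built when n = 1

// At n = 0 the loop ends with `cont 0`, i.e. k3 0,
// which calls k2 1, then k1 2, then k0 3.
printfn "%d %d" (countCPS 3 id) (k3 0)   // prints "3 3"
```

Each `k` holds a reference to the previous one, so applying the outermost continuation is a chain of calls, one per original recursion step.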

Let's dive into the generated code, decompiled into C#.

For countCPS, we get:

public static a countCPS<a>(int n, FSharpFunc<int, a> cont)
{
    while (n != 0)
    {
        int arg_1B_0 = n - 1;
        cont = new Program<a>.countCPS@10(cont);
        n = arg_1B_0;
    }
    return cont.Invoke(0);
}

There we go: the tail recursion actually got optimized into a loop. However, the closure class looks like this:

internal class countCPS@10<a> : FSharpFunc<int, a>
{
    public FSharpFunc<int, a> cont;

    internal countCPS@10(FSharpFunc<int, a> cont)
    {
        this.cont = cont;
    }

    public override a Invoke(int ret)
    {
        return this.cont.Invoke(ret + 1);
    }
}

So running the outermost closure causes it to .Invoke its child closure, which invokes its own child closure, and so on. We really have 60000 nested function calls again.

So I don't see how the continuation trick is actually able to do what's advertised.

Now, one could argue that this.cont.Invoke is itself a tail call, so it doesn't need a stack frame. Does .NET perform this kind of optimization? And what about more complicated examples like

let rec fib_cps n k = match n with
  | 0 | 1 -> k 1
  | n -> fib_cps (n-1) (fun a -> fib_cps (n-2) (fun b -> k (a+b)))

At least we would have to argue why we can optimize away the nested function calls captured in the continuation.
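For reference, `fib_cps` is started with the identity continuation. A minimal usage sketch (with this definition, `fib_cps 0` and `fib_cps 1` both yield 1, so `fib_cps 10` gives the 11th Fibonacci number):

```fsharp
let rec fib_cps n k =
    match n with
    | 0 | 1 -> k 1
    | n -> fib_cps (n - 1) (fun a -> fib_cps (n - 2) (fun b -> k (a + b)))

// Run with the identity continuation to get the plain result.
printfn "%d" (fib_cps 10 id)   // prints 89
```

Note that here a continuation (`fun a -> ...`) itself calls `fib_cps` again, so the chain is not a simple list of "add one" steps as in `countCPS`.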


Edit

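    using System;

    // Simplified stand-in for F#'s function type
    // (in FSharp.Core, FSharpFunc<T, TResult> is an abstract class).
    interface FSharpFunc<A, B>
    {
        B Invoke(A arg);
    }

    // Models the closure `fun ret -> cont (ret + 1)` created by countCPS.
    class Closure<A> : FSharpFunc<int, A>
    {
        public FSharpFunc<int, A> cont;

        public Closure(FSharpFunc<int, A> cont)
        {
            this.cont = cont;
        }

        public A Invoke(int arg)
        {
            // Calls the child closure: one nested Invoke per link in the chain.
            return cont.Invoke(arg + 1);
        }
    }

    // Models the initial continuation (fun x -> x).
    class Identity<A> : FSharpFunc<A, A>
    {
        public A Invoke(A arg)
        {
            return arg;
        }
    }

    static class Program
    {
        static void Main(string[] args)
        {
            FSharpFunc<int, int> computation = new Identity<int>();

            // Build the same chain that countCPS 10 (fun x -> x) would build.
            for (int n = 10; n > 0; --n)
                computation = new Closure<int>(computation);

            Console.WriteLine(computation.Invoke(0)); // prints 10
        }
    }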

To be even more precise, the code above models in C# the chain of closures that the CPS-style function builds up.

Clearly, the data lie on the heap. However, evaluating computation.Invoke(0) results in a cascade of nested Invoke calls on the child closures. Just put a breakpoint on Identity.Invoke and look at the stack trace! So how does the built-up computation trade stack space for heap space if it in fact uses both heavily?
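For comparison, the same chain can be built directly in F#. This is a sketch; `buildChain` is a hypothetical helper, not something from the question. Each fold step wraps the previous function in an "add one" closure, just as `countCPS` and the C# `Closure<A>` do, and applying the result to 0 walks the whole chain. Whether a deep chain grows the stack when applied depends on whether the calls inside the closures are compiled as tail calls, so the depth here is kept small.

```fsharp
// Hypothetical helper mirroring Main in the C# model:
// wrap `depth` "add one" closures around the identity function.
let buildChain (depth: int) : int -> int =
    List.fold (fun (k: int -> int) _ -> fun ret -> k (ret + 1)) id [1 .. depth]

let computation = buildChain 10
printfn "%d" (computation 0)   // prints 10
```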


Solution

  • There are a number of concepts here.

    For a tail-recursive function, the compiler can turn the recursion into a loop, so it needs no additional stack or heap space per call. You can rewrite your count function as a simple tail-recursive function with an accumulator:

    let rec count acc n = 
       if n = 0
          then acc
          else count (acc + 1) (n - 1)
    

    This will be compiled into a method with a while loop that makes no recursive calls.

    Continuations are generally needed when a function cannot be written as tail-recursive. Then you need to keep some state either on the stack or on the heap. Ignoring the fact that fib can be written more efficiently, the naïve recursive implementation would be:

    let rec fib n = 
      if n <= 1 then 1
      else (fib (n-1)) + (fib (n-2))
    

    This needs stack space to remember what needs to happen after the first recursive call returns the result (we then need to call the other recursive call and add the results). Using continuations, you can turn this into heap-allocated functions:

    let rec fib n cont = 
      if n <= 1 then cont 1
      else fib (n-1) (fun r1 -> 
             fib (n-2) (fun r2 -> cont (r1 + r2)))
    

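    This allocates one continuation (function value) for each recursive call, but it is tail-recursive, so it will not exhaust the available stack space. The same holds when the chain of continuations is finally invoked: the call inside each closure is also in tail position, and when tail calls are enabled (the F# compiler's --tailcalls+ option, on by default in Release builds) the compiler emits such calls with the IL .tail prefix, which the decompiled C# cannot show. In a Debug build, where tail calls are off by default, invoking a deep chain can still overflow the stack.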
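A small sketch that puts the answer's definitions together so they can be run side by side. The CPS version is renamed `fibCps` here (and both Fibonacci functions are declared `let rec`) so the naive and CPS versions can coexist in one file; the inputs are small and chosen only to check that the versions agree.

```fsharp
// Accumulator version: the recursive call is a direct self tail call,
// which the F# compiler turns into a loop.
let rec count acc n =
    if n = 0 then acc
    else count (acc + 1) (n - 1)

// Naive fib: the addition happens after both recursive calls return,
// so each pending addition occupies a stack frame.
let rec fib n =
    if n <= 1 then 1
    else fib (n - 1) + fib (n - 2)

// CPS fib: the pending "+" is captured in heap-allocated closures.
let rec fibCps n cont =
    if n <= 1 then cont 1
    else fibCps (n - 1) (fun r1 ->
           fibCps (n - 2) (fun r2 -> cont (r1 + r2)))

printfn "%d" (count 0 1000000)            // prints 1000000
printfn "%d %d" (fib 20) (fibCps 20 id)   // prints "10946 10946"
```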