f#, compiler-construction, lisp

How to avoid stack overflow during CPS conversion?


I'm writing a transformation from a subset of Scheme to a CPS language. It is implemented in F#. On large input programs, the conversion fails with a stack overflow.

I'm using a variant of the algorithm described in the paper Compiling with Continuations. I tried increasing the maximum stack size of the worker thread to 50 MB, and with that it works.

Is there some way to modify the algorithm so that I don't need to tune the stack size?

For example, the algorithm transforms

(foo (bar 1) (bar 2))

to

(let ((c1 (cont (r1)
           (let ((c2 (cont (r2)
                  (foo halt r1 r2))))
            (bar c2 2)))))
 (bar c1 1))

where halt is a final continuation which finishes the program.
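To make the recursion concrete, here is a toy converter for nested calls only. It is a hypothetical sketch, not the algorithm from Compiling with Continuations: the names cps-convert and cps-args, the C1/R1 naming scheme, and the assumption that every list is a call with a symbol in operator position are all chosen just to reproduce the output above. Note that the remaining arguments are converted inside the construction of each let, so the Lisp stack grows with the number of nested calls in the input; presumably this is the same shape of recursion that overflows in the F# version.

```lisp
;; Toy CPS converter for nested function calls only (a hypothetical sketch).
;; Every list is treated as a call (f arg ...) with a symbol as operator;
;; atoms are already values. Continuations are named C1, C2, ... and their
;; result parameters R1, R2, ...
(defvar *counter* 0
  "Index used to name the next continuation/result pair.")

(defun cps-args (args acc build)
  "Convert ARGS left to right, collecting their values in ACC (reversed),
then call BUILD on the list of values to produce the innermost expression."
  (cond ((null args)
         (funcall build (reverse acc)))
        ((atom (first args))
         ;; A trivial argument is used as-is.
         (cps-args (rest args) (cons (first args) acc) build))
        (t
         ;; A nested call (g x ...): convert its own arguments, call g with a
         ;; fresh continuation cN, and convert the remaining arguments inside
         ;; cN's body, where the call's result is available as rN.
         ;; This recursion is NOT a tail call, hence the stack growth.
         (destructuring-bind (g &rest gargs) (first args)
           (cps-args gargs '()
                     (lambda (gvals)
                       (let* ((n (incf *counter*))
                              (c (intern (format nil "C~D" n)))
                              (r (intern (format nil "R~D" n))))
                         `(let ((,c (cont (,r)
                                      ,(cps-args (rest args) (cons r acc) build))))
                            (,g ,c ,@gvals)))))))))

(defun cps-convert (call &optional (k 'halt))
  "Convert CALL, a list (f arg ...), passing K as f's continuation."
  (let ((*counter* 0))
    (destructuring-bind (f &rest args) call
      (cps-args args '() (lambda (vals) `(,f ,k ,@vals))))))
```

With this sketch, (cps-convert '(foo (bar 1) (bar 2))) returns the let form shown above (with symbols printed in upper case).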


Solution

    Maybe your actual problem has a simple solution that avoids heavy stack consumption, so please don't hesitate to add details. Without more knowledge about your particular code, however, here is a general approach to reducing the stack consumption of recursive programs, based on trampolines and continuations.

    Walker

    Here is a typical recursive function that is not trivially tail-recursive, written in Common Lisp because I don't know F#:

    (defun walk (form transform join)
      (typecase form
        (cons (funcall join
                       (walk (car form) transform join)
                       (walk (cdr form) transform join)))
        (t (funcall transform form))))
    

    The code is quite simple, though, and hopefully readable. It walks a tree made of cons cells:

    1. If the form is a cons cell, recursively walk its car and its cdr, then join the two results.
    2. Otherwise, apply transform to the value.

    For example:

    (walk '(a (b c d) 3 2 (a 2 1) 0)
          (lambda (u) (and (numberp u) u))
          (lambda (a b) (if a (cons a b) (or a b))))
    
    => (3 2 (2 1) 0)
    

    The code walks the form and retains only numbers, while preserving (non-empty) nesting.

    Calling trace on walk with the above example shows a maximal depth of 8 nested calls.
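    To check that figure without trace, here is a sketch of an instrumented copy of walk; walk/depth, max-walk-depth, *depth* and *max-depth* are hypothetical names introduced for this illustration. Counting the outermost call as depth 1, it reports 9 for the example above, i.e. the outermost call plus 8 nested levels. Because the recursion on the cdr is not a tail call either, the depth grows with the length of a list and not only with its nesting: a flat list of n elements needs n + 1 levels.

```lisp
;; Instrumented copy of WALK (hypothetical name WALK/DEPTH) that records
;; how many calls are nested on the stack at the deepest point.
(defvar *depth* 0
  "Number of WALK/DEPTH calls currently active.")
(defvar *max-depth* 0
  "Largest value of *DEPTH* seen so far.")

(defun walk/depth (form transform join)
  ;; Rebinding *DEPTH* dynamically restores it automatically on return.
  (let ((*depth* (1+ *depth*)))
    (setf *max-depth* (max *max-depth* *depth*))
    (typecase form
      (cons (funcall join
                     (walk/depth (car form) transform join)
                     (walk/depth (cdr form) transform join)))
      (t (funcall transform form)))))

(defun max-walk-depth (form transform join)
  "Walk FORM; return the result and the maximum nesting depth reached."
  (let ((*max-depth* 0))
    (values (walk/depth form transform join) *max-depth*)))
```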

    Continuations and trampoline

    Here is an adapted version, called walk/then, that walks a form as before and, when a result is available, calls then on it. The then argument is a continuation.

    For a cons cell, the function returns a thunk, i.e. a parameterless closure, instead of recursing right away. Returning the closure unwinds the stack; when the thunk is later applied, it starts from a fresh stack, but the computation has advanced (I usually picture someone walking up an escalator that goes down). Returning thunks in order to limit the number of stack frames is part of the trampoline.

    The then function takes one value: the result that the current walk would eventually have returned. Results are thus passed down the stack as arguments instead of being returned up it, and each step returns a thunk until the final result is available.

    Nesting continuations captures the more complex behaviour of transform/join: the remaining parts of the computation are pushed into nested continuations.

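    (defun walk/then (form transform join then)
      (typecase form
        ;; Cons cell: return a thunk instead of recursing right away.
        (cons (lambda ()
                ;; Walk the car; its result V goes to the first continuation...
                (walk/then (car form) transform join
                           (lambda (v)
                             ;; ...which walks the cdr and receives its result W.
                             (walk/then (cdr form) transform join
                                        (lambda (w)
                                          ;; Join both results and hand them to THEN.
                                          (funcall then (funcall join v w))))))))
        ;; Leaf: transform it and pass the result to the continuation.
        (t (funcall then (funcall transform form)))))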
    

    For example, (walk/then (car form) transform join (lambda (v) ...)) reads as follows: walk the car of form with arguments transform and join, and eventually call (lambda (v) ...) on the result; that continuation walks down the cdr, joins both results, and finally calls the input then on the joined result.

    What is missing is a way to keep calling the returned thunks until a non-function result comes back; here it is with a loop, though it could just as easily be a tail-recursive function:

    (loop for res = 
         (walk/then '(a (b c d) 3 2 (a 2 1) 0)
                    (lambda (u) (and (numberp u) u))
                    (lambda (a b) (if a (cons a b) (or a b)))
                    #'identity)
       then (typecase res (function (funcall res)) (t res))
       while (functionp res)
       finally (return res))
    

    The above returns (3 2 (2 1) 0), and when walk/then is traced, the trace depth never goes above 2.
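    Here is a sketch of the tail-recursive variant of that driver; the name trampoline is hypothetical. Like the loop, it treats any function value as a pending thunk, so it only works when the final result of the computation is not itself a function. In F#, a directly self-recursive call in tail position is compiled into a loop, so this form is safe there; Common Lisp leaves tail-call elimination to the implementation, which makes the loop the more portable choice.

```lisp
;; Tail-recursive version of the driver loop (hypothetical name TRAMPOLINE).
;; As in the LOOP version, any FUNCTION value is treated as a pending thunk,
;; so the final result of the computation must not be a function.
(defun trampoline (value)
  (if (functionp value)
      ;; Call the thunk and bounce again on whatever it returns.
      (trampoline (funcall value))
      ;; Anything else is the final result.
      value))
```

    Called on the same walk/then expression as the loop above (same form, transform, join and #'identity), it returns (3 2 (2 1) 0) as well.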

    See Eli Bendersky's article for another take on this, in Python.
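    One caveat for a port: in walk/then, the joined result is handed to then by a direct call, so when a deeply nested form (or the end of a long list) is finished, the pending continuations are unwound through a chain of nested calls. Those calls are all in tail position, so they only stay off the stack if the implementation eliminates tail calls, which Common Lisp does not guarantee. A sketch of a fix, assuming one extra bounce per join is acceptable, is to return that call as a thunk too. walk/then* and drive are hypothetical names; everything else mirrors walk/then.

```lisp
;; Hypothetical variant of WALK/THEN: identical except that the call to THEN
;; after a join is itself returned as a thunk, so every step returns quickly
;; to the driver instead of unwinding a chain of continuations directly.
(defun walk/then* (form transform join then)
  (typecase form
    (cons (lambda ()
            (walk/then* (car form) transform join
                        (lambda (v)
                          (walk/then* (cdr form) transform join
                                      (lambda (w)
                                        ;; The only change: bounce once more.
                                        (lambda ()
                                          (funcall then (funcall join v w)))))))))
    (t (funcall then (funcall transform form)))))

;; Driver: keep calling thunks until a non-function value comes back.
;; It uses LOOP, so it does not rely on tail-call elimination either.
(defun drive (value)
  (loop while (functionp value)
        do (setf value (funcall value)))
  value)
```

    With this change, each call made by drive does a bounded amount of work before returning, so the stack depth no longer depends on how deeply the form is nested.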