Tags: recursion, common-lisp, tail-call-optimization

Is there a TCO pattern with two accumulating variables?


Just for fun (Project Euler #65), I want to implement the formula

$n_k = a_k n_{k-1} + n_{k-2}$

in an efficient way. For k ≥ 3, a_k is (* 2 (/ k 3)), i.e. 2k/3, when k is a multiple of 3, and 1 otherwise; the base cases are n_1 = 2 and n_2 = 3.
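
Unrolling the first few terms from the base cases n_1 = 2 and n_2 = 3 (the values used in the code below) shows where the non-trivial coefficients a_3 = 2 and a_6 = 4 enter, with a_k = 1 in between:

$$
\begin{aligned}
n_3 &= a_3 n_2 + n_1 = 2 \cdot 3 + 2 = 8,\\
n_4 &= a_4 n_3 + n_2 = 1 \cdot 8 + 3 = 11,\\
n_5 &= a_5 n_4 + n_3 = 1 \cdot 11 + 8 = 19,\\
n_6 &= a_6 n_5 + n_4 = 4 \cdot 19 + 11 = 87.
\end{aligned}
$$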

I started with a recursive solution:

(defun numerator-of-convergence-for-e-rec (k)
  "Returns the Kth numerator of convergence for Euler's number e."
  (cond ((or (minusp k) (zerop k)) 0)   ; k <= 0
        ((= 1 k) 2)
        ((= 2 k) 3)
        ;; a_k = 2k/3 when k is a multiple of 3
        ((zerop (mod k 3)) (+ (* 2 (/ k 3) (numerator-of-convergence-for-e-rec (1- k)))
                              (numerator-of-convergence-for-e-rec (- k 2))))
        ;; a_k = 1 otherwise
        (t (+ (numerator-of-convergence-for-e-rec (1- k))
              (numerator-of-convergence-for-e-rec (- k 2))))))

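This works for small k but gets very slow for k = 100: each call makes two recursive calls, so the same values are recomputed over and over and the running time grows exponentially with k.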

I have no real idea how to transform this function into a version that could be tail-call optimized. I have seen a pattern using two accumulating variables for Fibonacci numbers, but I fail to apply this pattern to my function.
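
For reference, the two-accumulator Fibonacci pattern usually looks something like the following sketch (the names fib and fib-iter are illustrative, not taken from the post). The two most recent values travel along as arguments, so the recursive call is the last thing the helper does:

```lisp
;; Illustrative sketch: FIB and FIB-ITER are hypothetical names.
(defun fib (n)
  "Returns the Nth Fibonacci number, with (fib 0) = 0 and (fib 1) = 1."
  (labels ((fib-iter (i a b)
             ;; A holds F(i) and B holds F(i+1); each call shifts them by one.
             (if (= i n)
                 a
                 (fib-iter (1+ i) b (+ a b)))))
    (fib-iter 0 0 1)))
```

With this definition, (fib 10) returns 55.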

Is there a general guideline for transforming complex recursions into TCO versions, or should I implement an iterative solution directly?


Solution

  • First, note that memoization is probably the simplest way to optimize your code: it does not reverse the flow of operations. You call your function with a given k and it goes back down to zero to compute the previous values, but with a cache. If, however, you want to turn your function from recursive into iterative with TCO, you'll have to compute things from zero up to k and pretend you only have a constant-sized stack / memory.
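
    A minimal memoized sketch, assuming a global hash table as the cache (*numerator-cache* and numerator-memo are hypothetical names). The recursion has the same shape as in the question, but each k is computed only once, so the cost becomes linear in k:

    ```lisp
    ;; Sketch: *NUMERATOR-CACHE* and NUMERATOR-MEMO are hypothetical names.
    (defvar *numerator-cache* (make-hash-table)
      "Maps K to the Kth numerator once it has been computed.")

    (defun numerator-memo (k)
      "Same recursion as the original, but each K is computed only once."
      (cond ((<= k 0) 0)
            ((= k 1) 2)
            ((= k 2) 3)
            (t (multiple-value-bind (cached found) (gethash k *numerator-cache*)
                 (if found
                     cached
                     ;; SETF of GETHASH stores the new value and returns it.
                     (setf (gethash k *numerator-cache*)
                           (+ (* (if (zerop (mod k 3)) (* 2 (/ k 3)) 1) ; a_k
                                 (numerator-memo (1- k)))
                              (numerator-memo (- k 2)))))))))
    ```

    The recursion depth is still about k, which is harmless for k = 100.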

    Step function

    First, write a function which computes the current n given k and the two previous values n1 and n2 (that is, n_{k-1} and n_{k-2}):

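    (defun n (k n1 n2)
      "Returns the numerator for K, given N1 and N2, the numerators for K-1 and K-2."
      (if (plusp k)
          (case k
            (1 2)                                    ; base cases
            (2 3)
            (t (multiple-value-bind (quotient remainder) (floor k 3)
                 (if (zerop remainder)
                     (+ (* 2 quotient n1) n2)        ; a_k = 2k/3
                     (+ n1 n2)))))                   ; a_k = 1
          0))                                        ; k <= 0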
    

    This step should be easy: I rewrote your function a little bit, but I actually only extracted the part that computes n from k and the two previous values.

    Modified function with recursive (iterative) calls

    Now you need to compute n for k going from 0 up to the maximal value you want, named m hereafter. Thus, I add a parameter m, which controls when the recursion stops, and make the function call itself with updated arguments. Notice that the arguments are shifted: the newly computed value becomes the next n1, and the current n1 becomes the next n2.

    (defun f (m k n1 n2)
      ;; N1 and N2 hold the numerators for K-1 and K-2.
      (if (< m k)
          n1                                ; K = M+1, so N1 is the numerator for M
          (if (plusp k)
              (case k
                (1 (f m (1+ k) 2 n1))
                (2 (f m (1+ k) 3 n1))
                (t (multiple-value-bind (quotient remainder) (floor k 3)
                     (if (zerop remainder)
                         (f m (1+ k) (+ (* 2 quotient n1) n2) n1)
                         (f m (1+ k) (+ n1 n2) n1)))))
              (f m (1+ k) 0 n1))))
    

    That's all, except that you don't want to expose this interface to your users. The actual function g properly bootstraps the initial call to f:

    (defun g (m)
      (f m 0 0 0))
    

    The trace of this function has an arrow ">" shape, which is typical of tail-recursive functions (note that tracing is likely to inhibit tail-call optimization):

      0: (G 5)
        1: (F 5 0 0 0)
          2: (F 5 1 0 0)
            3: (F 5 2 2 0)
              4: (F 5 3 3 2)
                5: (F 5 4 8 3)
                  6: (F 5 5 11 8)
                    7: (F 5 6 19 11)
                    7: F returned 19
                  6: F returned 19
                5: F returned 19
              4: F returned 19
            3: F returned 19
          2: F returned 19
        1: F returned 19
      0: G returned 19
    19
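
    Instead of defining f globally next to g, the accumulating helper can also be kept local with LABELS, which plays a role similar to Scheme's named let. A sketch, where numerator-tail and advance are hypothetical names; whether the self-call is actually compiled as a jump still depends on the implementation:

    ```lisp
    ;; Sketch: NUMERATOR-TAIL and ADVANCE are hypothetical names.
    (defun numerator-tail (m)
      "Returns the Mth numerator of the convergents of e (0 for M <= 0)."
      (labels ((advance (k n1 n2)
                 ;; N1 and N2 hold the numerators for K-1 and K-2.
                 (if (> k m)
                     n1
                     (advance (1+ k)
                              (cond ((<= k 0) 0)
                                    ((= k 1) 2)
                                    ((= k 2) 3)
                                    ((zerop (mod k 3)) (+ (* 2 (/ k 3) n1) n2))
                                    (t (+ n1 n2)))
                              n1))))
        (advance 0 0 0)))
    ```

    (numerator-tail 5) returns 19, the same value as in the trace above.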
    

    Driver function with a loop

    The part that can be slightly difficult, or make your code hard to read, is injecting tail-recursive calls into the original function n. I think it is better to use a loop instead, because:

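    1. unlike with a tail-recursive call, you can guarantee that the code behaves as you wish, without worrying whether your implementation actually optimizes tail calls: the Common Lisp standard does not require tail-call elimination, and implementations that perform it typically do so only under certain optimization settings.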
    2. the code for the step function n is simpler and only expresses what is happening, instead of detailing how (tail-recursive calls are just an implementation detail here).

    With the above function n, you can change g to:

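    (defun g (m)
      ;; FOR clauses are stepped in order, so N2 takes the old N1
      ;; before N1 takes the old N.
      (loop
         for k from 0 to m
         for n2 = 0 then n1          ; numerator for K-2
         for n1 = 0 then n           ; numerator for K-1
         for n = (n k n1 n2)         ; numerator for K
         finally (return n)))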
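
    For comparison, the same iteration can also be written with DO, which evaluates all step forms with the old values and then assigns them in parallel, mirroring the argument shift of f directly. A sketch; numerator-do is a hypothetical name and the body of n is inlined:

    ```lisp
    ;; Sketch: NUMERATOR-DO is a hypothetical name; the step logic of N is inlined.
    (defun numerator-do (m)
      "Returns the Mth numerator of the convergents of e (0 for M <= 0)."
      ;; DO evaluates every step form with the old values, then assigns them
      ;; in parallel, just like the arguments of the tail call to F.
      (do ((k 0 (1+ k))
           (n1 0 (cond ((<= k 0) 0)
                       ((= k 1) 2)
                       ((= k 2) 3)
                       ((zerop (mod k 3)) (+ (* 2 (/ k 3) n1) n2))
                       (t (+ n1 n2))))
           (n2 0 n1))
          ((> k m) n1)))
    ```
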
    

    Is there a general guideline for transforming complex recursions into TCO versions, or should I implement an iterative solution directly?

    Find a step function which advances the computation from the base cases towards the general case, and pass the intermediate values as parameters, in particular the results of past calls. This function can call itself (in which case it is tail-recursive, because all the arguments are computed before the call), or simply be called in a loop. Be careful when computing the initial values: you might have more corner cases than with a simple recursive function.
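
    As a sketch of that recipe, assuming the hypothetical names iterate-recurrence and numerator-from-base: a generic driver applies any step function of (k n1 n2) in a loop, and the corner cases move into the choice of initial values. Starting at k = 2 with n1 = n_1 = 2 and a fictitious n2 = n_0 = 1 lets the plain recurrence produce n_2 = 1 * 2 + 1 = 3, so only k <= 0 still needs a special case:

    ```lisp
    ;; Sketch: ITERATE-RECURRENCE and NUMERATOR-FROM-BASE are hypothetical names.
    (defun iterate-recurrence (step-fn m k n1 n2)
      "Applies STEP-FN to (K N1 N2) for indices K up to M; returns the last value."
      ;; N1 and N2 start as the values for indices K-1 and K-2.
      (loop while (<= k m)
            do (psetf n1 (funcall step-fn k n1 n2) ; new value for index K
                      n2 n1                        ; old N1 becomes N2
                      k (1+ k)))
      n1)

    (defun numerator-from-base (m)
      "Returns the Mth numerator of the convergents of e (0 for M <= 0)."
      (if (<= m 0)
          0
          ;; Start at K = 2 with N1 = n_1 = 2 and the fictitious N2 = n_0 = 1,
          ;; so the plain recurrence yields n_2 = 1*2 + 1 = 3.
          (iterate-recurrence (lambda (k n1 n2)
                                (if (zerop (mod k 3))
                                    (+ (* 2 (/ k 3) n1) n2)
                                    (+ n1 n2)))
                              m 2 2 1)))
    ```

    The step function no longer contains any base case; the price is that the fictitious n_0 must be chosen so the recurrence stays valid at k = 2.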

    See also

    Scheme's named let, the RECUR macro in Common Lisp and the recur special form in Clojure.