
Continuation-passing style does not seem to make a difference in Clojure


I have been experimenting with continuation-passing style (CPS), as I may need to deal with some non-tail-recursive functions soon. Good technique to know, in any case! I wrote a test function in both Lua and Clojure and ran the Lua one on my little Android handheld.

The Lua version seems to have worked fine. Lua's stack already has a depth of about 300000, but with CPS I was easily able to do over 7000000 iterations before the system blew up, probably from lack of memory rather than from any limitation of the CPS/Lua combination.

The Clojure attempt fared less well. At a little over 1000 iterations it was already complaining of a blown stack, whereas the plain, non-CPS version does better, reaching a stack depth of about 1600, iirc.

Any ideas what might be the problem? Something inherent to the JVM, perhaps, or just some silly noob error? (Oh, BTW, the test function, sigma(log), was chosen because it grows slowly, and Lua does not support bignums on Android.)

All ideas, hints, suggestions most welcome.

The Clojure code:

user=> (defn cps2 [op]
  #_=>   (fn [a b k] (k (op a b))))
#'user/cps2

user=> (defn cps-sigma [n k]
  #_=>  ((cps2 =) n 1 (fn [b]
  #_=>           (if b                    ; growing continuation
  #_=>               (k 0)                ; in the recursive call
  #_=>               ((cps2 -) n 1 (fn [nm1]
  #_=>                        (cps-sigma nm1 (fn [f]
  #_=>                                          ((cps2 +) (Math/log n) f k)))))))))
#'user/cps-sigma

user=> (cps-sigma 1000 identity)
5912.128178488171

user=> (cps-sigma 1500 identity)

StackOverflowError   clojure.lang.Numbers.equal (Numbers.java:216)
user=> 
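One hedged reading of the overflow: Clojure runs on the JVM, which does not eliminate tail calls, so each (k ...) call in cps-sigma is an ordinary nested call that keeps its stack frame. CPS alone moves the pending work into continuations but does not free the stack. A minimal sketch that combines CPS with trampoline, where cps-sigma-t and sigma-cps are hypothetical names and the base case is (<= n 1) rather than (= n 1):

```clojure
(defn cps-sigma-t
  "Trampolined CPS version of cps-sigma: instead of calling deeper,
  every step and every continuation returns a zero-argument thunk."
  [n k]
  (if (<= n 1)
    (fn [] (k 0))
    (fn []
      (cps-sigma-t (dec n)
                   ;; new continuation: add log n to the result of the
                   ;; recursive call, then pass it on to k via another thunk
                   (fn [f] (fn [] (k (+ (Math/log n) f))))))))

(defn sigma-cps [n]
  ;; identity returns a plain number, which is not a fn,
  ;; so trampoline stops bouncing and returns it
  (trampoline cps-sigma-t n identity))
```

Because the additions happen in the same order as in cps-sigma, (sigma-cps 1000) should give the same value as above. Each thunk returns before the next one runs, so the stack stays flat; the heap still holds one closure per level.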

===================

PS. After experimenting a bit, I tried the code I mentioned in my third comment below:

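(defn mk-cps [accept? end-value kend kont]
  (fn [n]
    ((fn [n k]
       ;; wrap the current continuation k in a new one that folds in n
       (let [cont (fn [v] (k (kont v n)))]
         (if (accept? n)
           (k end-value)             ; base case: run the whole continuation chain
           (recur (dec n) cont))))   ; recur loops without growing the stack
     n kend)))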

(def sigmaln-cps (mk-cps zero? 0 identity #(+ %1 (Math/log %2))))

user=> (sigmaln-cps 11819) ;; #11819 iterations first try

StackOverflowError   clojure.lang.RT.doubleCast (RT.java:1312)

That's obviously better, by an order of magnitude; however, I still think it's way too low. Technically it should be limited only by memory, yes?

I mean, the toy Lua system on a toy Android tablet did over 7000000...
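A possible explanation for the mk-cps limit, offered as a reading of the code rather than a diagnosis: the recur loop itself runs in constant stack, but it builds a chain of n nested cont closures, and the final (k end-value) call then invokes each one inside the previous, so the unwinding is still about n frames deep. On the JVM that depth is bounded by the thread's stack size (set with -Xss), not by heap memory. Because this particular continuation only adds numbers, the chain can be collapsed into an accumulator. A minimal sketch with loop/recur, where sigma-loop is a hypothetical name:

```clojure
(defn sigma-loop
  "Sums (Math/log i) for i from n down to 2 using loop/recur.
  recur compiles to a jump, so stack usage does not grow with n."
  [n]
  (loop [i   n
         acc 0.0]                    ; running total replaces the continuation chain
    (if (<= i 1)
      acc
      (recur (dec i) (+ acc (Math/log (double i)))))))
```

The trade-off is that this only works because addition can be done eagerly; a continuation that must run after the recursive result is known cannot always be flattened this way.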


Solution

  • Clojure has the trampoline function, which can remove a lot of the confusing plumbing involved in this problem:

    (defn sigma [n]
      (letfn [(sig [curr n]
                (if (<= n 1)
                  curr
                  #(sig (+ curr (Math/log n)) (dec n))))]
        (trampoline sig 0 n)))
    
    (sigma 1000)
    => 5912.128178488164
    (sigma 1500)
    => 9474.406184917756
    (sigma 1e7) ;; might take a few seconds
    => 1.511809654875759E8
    

    The function you pass to trampoline can return either a new zero-argument function, in which case trampoline calls it and keeps "bouncing", or a non-function value, which it returns as the final result. This example doesn't involve mutually recursive functions, but those are also doable with trampoline.
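As an illustration of the mutual-recursion case, here is the usual even/odd sketch; my-even? and my-odd? are hypothetical names chosen to avoid clashing with clojure.core/even? and odd?. Each function returns a thunk that calls the other, and trampoline keeps bouncing until it gets a boolean back:

```clojure
(declare my-odd?)                 ; forward declaration for the mutual reference

(defn my-even? [n]
  (if (zero? n)
    true
    #(my-odd? (dec n))))          ; return a thunk instead of calling directly

(defn my-odd? [n]
  (if (zero? n)
    false
    #(my-even? (dec n))))

;; Usage: start the bouncing with trampoline
(trampoline my-even? 1000000)
```

Calling (my-even? 1000000) directly would just return the first thunk; the recursion only runs, in constant stack, when driven by trampoline.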