
Is there a simpler way to memoize a recursive let fn?


Let's say you have a recursive function defined in a let block:

(let [fib (fn fib [n]
            (if (< n 2)
              n
              (+ (fib (- n 1))
                 (fib (- n 2)))))]
  (fib 42))

This can be mechanically transformed to use memoize:

  1. Wrap the fn form in a call to memoize.
  2. Move the function name in as the first argument.
  3. Pass the function to itself wherever it is called.
  4. Rebind the function symbol with partial so that it passes itself automatically.

Transforming the above code leads to:

(let [fib (memoize
            (fn [fib n]
              (if (< n 2)
                n
                (+ (fib fib (- n 1))
                   (fib fib (- n 2))))))
      fib (partial fib fib)]
  (fib 42))
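A quick way to confirm that the self-passing version really memoizes is to count how many times the function body runs. In the sketch below, the calls atom and the top-level defs are hypothetical additions for illustration; the function itself is the same as in the transformed code above.

```clojure
;; Hypothetical instrumentation: count how often the memoized body executes.
(def calls (atom 0))

(def result
  (let [fib (memoize
              (fn [fib n]
                (swap! calls inc) ; only reached on a cache miss
                (if (< n 2)
                  n
                  (+ (fib fib (- n 1))
                     (fib fib (- n 2))))))
        fib (partial fib fib)]
    (fib 42)))

(println result @calls) ; prints: 267914296 43
```

Each n from 0 to 42 misses the cache exactly once, so the body runs 43 times, compared with the hundreds of millions of calls the unmemoized version makes.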

This works, but feels overly complicated. The question is: Is there a simpler way?


Solution

  • I'm taking a risk answering, since I'm no expert, but I don't think so. You've pretty much done the standard thing, which ultimately combines memoization with a fixed-point combinator, expressed through partial application.

    You could try fiddling with macros, though. For simple cases it could be easy: syntax-quote does name resolution for you, and you can operate on the resulting form. I'll try once I get home.

    Edit: I tried it out at home, and this seems OK-ish for simple cases:

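    ;; Rewrites a single-arity named (fn name [params] body) into the
    ;; self-passing memoized form from the question.
    (defmacro memoize-rec [form]
      (let [[_ fname params & body] form               ; skip the leading fn symbol
            params-with-fname (vec (cons fname params))]
        `(let [f# (memoize (fn ~params-with-fname
                             ;; rebind the name so recursive calls in body need no change
                             (let [~fname (partial ~fname ~fname)] ~@body)))]
           (partial f# f#))))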
    
    ;; (clojure.pprint/pprint (macroexpand '(memoize-rec (fn f [x] (str (f x))))))
    ((memoize-rec (fn fib [n]
                    (if (< n 2)
                      n
                      (+ (fib (- n 1))
                         (fib (- n 2)))))) 75) ;; instant
    
    ((fn fib [n]
       (if (< n 2)
         n
         (+ (fib (- n 1))
            (fib (- n 2))))) 75) ;; slooooooow (exponential time)
    

    Simpler than I thought!
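For comparison, here is a sketch (not from the original answer) of the same fixed-point idea as an ordinary higher-order function rather than a macro. The name memo-fix is hypothetical. The recursive step still receives the recursive function as its first parameter, but a promise lets the memoized function refer to itself, so the cache is keyed only on the real arguments and no partial rebinding is needed.

```clojure
;; Hypothetical helper: f takes the recursive function as its first
;; argument, followed by the real arguments.
(defn memo-fix [f]
  (let [self (promise)                                ; delivered below
        memo (memoize (fn [& args] (apply f @self args)))]
    (deliver self memo)                               ; memo now recurses through @self
    memo))

;; Usage inside a let, as in the question.
(let [fib (memo-fix
            (fn [fib n]
              (if (< n 2)
                n
                (+ (fib (- n 1))
                   (fib (- n 2))))))]
  (fib 75)) ;; => 2111485077978050
```

The trade-off versus the macro is that the caller adds the extra fib parameter by hand, but there is no macro to debug, and because memo is variadic the helper is not limited to single-arity functions.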