Tags: java, recursion, stack-overflow, tail-call-optimization

Achieving Stackless recursion in Java 8


How do I achieve stackless recursion in Java?

The word that seems to come up the most is "trampolining", and I have no clue what that means.

Could someone IN DETAIL explain how to achieve stackless recursion in Java? Also, what is "trampolining"?

If you cannot provide either of those, could you please point me in the right direction (e.g., a book to read about it or some tutorial that teaches all of these concepts)?


Solution

  • A trampoline is a pattern for turning stack-based recursion into an equivalent loop. Since loops don't add stack frames, this can be thought of as a form of stackless recursion.

    Here's a diagram I found helpful:

    [Diagram: the trampoline loop, from bartdesmet.net]

    You can think of a trampoline as a process that takes a starting value, iterates on that value, and then exits with the final value.


    Consider this stack-based recursion:

    public static int factorial(final int n) {
        if (n <= 1) {
            return 1;
        }
        return n * factorial(n - 1);
    }
    

    Every recursive call pushes a new frame, because the calling frame cannot finish evaluating n * factorial(n - 1) until the new frame returns its result. This becomes a problem when the recursion gets too deep: the thread's stack is exhausted and the JVM throws a StackOverflowError.
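    The following sketch shows the problem concretely. It recurses one frame per step and catches the resulting StackOverflowError. The class and method names are illustrative. The depth of 10,000,000 is an assumption chosen to exceed common default thread stack sizes; the actual limit depends on the JVM and its -Xss setting.

```java
public class StackDepthDemo {
    // Non-tail-recursive sum: each call must wait for the next one to return,
    // so every level keeps its own frame on the stack.
    static long sumTo(long n) {
        if (n == 0) {
            return 0;
        }
        return n + sumTo(n - 1);
    }

    // Returns true if computing sumTo(n) exhausts the thread's stack.
    static boolean overflows(long n) {
        try {
            sumTo(n);
            return false;
        } catch (StackOverflowError e) {
            return true;
        }
    }

    public static void main(String[] args) {
        System.out.println(overflows(1_000));      // shallow recursion: expected false
        System.out.println(overflows(10_000_000)); // deep recursion: expected true with default stack sizes
    }
}
```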

    Luckily, we can express this function as a loop:

    public static int factorial2(int n) {
        int i = 1; 
        while (n > 1) {
            i = i * n;
            n--;
        }
        return i;
    }
    

    What's going on here? We've taken the recursive step and made it the body of a loop. We loop until we have completed all recursive steps, storing the intermediate result of each iteration in a variable.

    This uses constant stack space, because only one frame is ever created. Instead of storing a frame for each recursive call (n frames), we store the running product and the number of iterations remaining (2 values).
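    An accumulator-passing version is a useful step between the two forms above. In that version, the recursive call is the very last thing the method does (a tail call). The sketch below puts it next to its loop; the method names are illustrative. The JVM does not eliminate tail calls, so the tail-recursive version still pushes one frame per call. What it buys is a mechanical translation: each parameter becomes a loop variable, and the recursive call becomes a reassignment of those variables. The sketch uses long instead of int because int overflows for n > 12.

```java
public class TailRecursionDemo {
    // Tail-recursive: the running product travels in 'acc', so nothing remains
    // to be done after the recursive call returns. Java still allocates a frame per call.
    static long factorialTail(int n, long acc) {
        if (n <= 1) {
            return acc;
        }
        return factorialTail(n - 1, acc * n);
    }

    // Same computation, with the tail call replaced by reassigning the "parameters".
    static long factorialLoop(int n) {
        long acc = 1;
        while (n > 1) {
            acc = acc * n; // corresponds to the argument 'acc * n' (uses the old n)
            n = n - 1;     // corresponds to the argument 'n - 1'
        }
        return acc;
    }

    public static void main(String[] args) {
        System.out.println(factorialTail(5, 1)); // 120
        System.out.println(factorialLoop(5));    // 120
        System.out.println(factorialLoop(20));   // 2432902008176640000
    }
}
```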

    The generalization of this pattern is a trampoline.

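    import java.util.Optional;

    public class Trampoline<T>
    {
        // Overridden by the final step to return the result.
        public T getValue() {
            throw new RuntimeException("Not implemented");
        }

        // Overridden by intermediate steps to return the next step; empty means "done".
        public Optional<Trampoline<T>> nextTrampoline() {
            return Optional.empty();
        }

        // The trampoline loop: follow next steps until there are none, then read the value.
        public final T compute() {
            Trampoline<T> trampoline = this;
            Optional<Trampoline<T>> next;

            // Call nextTrampoline() once per step; calling it twice would build each step twice.
            while ((next = trampoline.nextTrampoline()).isPresent()) {
                trampoline = next.get();
            }

            return trampoline.getValue();
        }
    }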
    

    The Trampoline requires two members:

    • the value of the final step (only meaningful once there is no next step);
    • the next step to compute, or nothing if we have reached the last step.

    Any computation that can be described in this way can be "trampolined".

    What does this look like for factorial?

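    import java.util.Optional;

    public final class Factorial
    {
        // n: remaining multiplier; product: result accumulated so far.
        public static Trampoline<Integer> createTrampoline(final int n, final int product)
        {
            // Base case. Using n <= 1 (not n == 1) means n == 0 stops instead of looping forever.
            if (n <= 1) {
                return new Trampoline<Integer>() {
                    @Override
                    public Integer getValue() { return product; }
                };
            }

            // Recursive case: return the next step instead of calling it directly.
            return new Trampoline<Integer>() {
                @Override
                public Optional<Trampoline<Integer>> nextTrampoline() {
                    return Optional.of(createTrampoline(n - 1, product * n));
                }
            };
        }
    }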
    

    And to call:

    int result = Factorial.createTrampoline(4, 1).compute(); // 24
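    Trampolines also handle mutual recursion, where two methods call each other. Mutual recursion does not map onto a single self-recursive loop as directly as factorial does. The sketch below repeats the Trampoline shape from above so that it compiles on its own. The helpers done and more, and the isEven/isOdd pair, are hypothetical illustrations. Each method returns the next step instead of calling the other method, so compute() runs a million steps without growing the stack.

```java
import java.util.Optional;
import java.util.function.Supplier;

public class MutualRecursionDemo {
    // Same shape as the Trampoline class above, repeated so this file compiles alone.
    static class Trampoline<T> {
        public T getValue() {
            throw new RuntimeException("Not implemented");
        }

        public Optional<Trampoline<T>> nextTrampoline() {
            return Optional.empty();
        }

        public final T compute() {
            Trampoline<T> trampoline = this;
            Optional<Trampoline<T>> next;
            while ((next = trampoline.nextTrampoline()).isPresent()) {
                trampoline = next.get();
            }
            return trampoline.getValue();
        }
    }

    // Hypothetical helper: a final step that holds the result.
    static <T> Trampoline<T> done(T value) {
        return new Trampoline<T>() {
            @Override
            public T getValue() { return value; }
        };
    }

    // Hypothetical helper: an intermediate step that produces the next step on demand.
    static <T> Trampoline<T> more(Supplier<Trampoline<T>> next) {
        return new Trampoline<T>() {
            @Override
            public Optional<Trampoline<T>> nextTrampoline() { return Optional.of(next.get()); }
        };
    }

    // isEven and isOdd refer to each other, but each only returns a step.
    static Trampoline<Boolean> isEven(int n) {
        if (n == 0) {
            return done(true);
        }
        return more(() -> isOdd(n - 1));
    }

    static Trampoline<Boolean> isOdd(int n) {
        if (n == 0) {
            return done(false);
        }
        return more(() -> isEven(n - 1));
    }

    public static void main(String[] args) {
        System.out.println(isEven(10).compute());       // true
        System.out.println(isOdd(1_000_001).compute()); // true
    }
}
```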
    

    Notes

    • This will be slower than the plain loop in Java: every step boxes an Integer and allocates a new Trampoline object.
    • This code was written on SO; it has not been tested or even compiled
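    In Java 8 the same idea can be written more compactly with lambdas instead of anonymous subclasses. The sketch below shows one possible design, not a standard library API; the names Bounce, done, more and run are hypothetical. A Bounce either holds a finished value or holds a Supplier for the next step, and run() is the loop that keeps bouncing until it reaches a finished step. It still boxes values, so the performance note above applies here too.

```java
import java.util.function.Supplier;

public class BounceDemo {
    // Hypothetical trampoline type: either a finished value or a deferred next step.
    static final class Bounce<T> {
        private final T value;                  // meaningful only when next == null
        private final Supplier<Bounce<T>> next; // non-null while there is more work

        private Bounce(T value, Supplier<Bounce<T>> next) {
            this.value = value;
            this.next = next;
        }

        static <R> Bounce<R> done(R value) {
            return new Bounce<>(value, null);
        }

        static <R> Bounce<R> more(Supplier<Bounce<R>> next) {
            return new Bounce<>(null, next);
        }

        // The trampoline loop: only one step is live at a time, so stack depth stays constant.
        T run() {
            Bounce<T> current = this;
            while (current.next != null) {
                current = current.next.get();
            }
            return current.value;
        }
    }

    // Sum of 1..n in accumulator style; each step describes the next step
    // instead of calling itself.
    static Bounce<Long> sumTo(long n, long acc) {
        if (n == 0) {
            return Bounce.done(acc);
        }
        return Bounce.more(() -> sumTo(n - 1, acc + n));
    }

    public static void main(String[] args) {
        System.out.println(sumTo(10, 0).run());        // 55
        System.out.println(sumTo(1_000_000, 0).run()); // 500000500000
    }
}
```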

    Further reading