Tags: python, c++, recursion, stack-overflow

Why is Python recursion so expensive and what can we do about it?


Suppose we want to compute some Fibonacci numbers, modulo 997.

For n=500 in C++ we can run

#include <iostream>
#include <array>

std::array<int, 2> fib(unsigned n) {
    if (!n)
        return {0, 1};  // base pair (F(0), F(1)), so fib(1) is {1, 2} as in the Python version
    auto x = fib(n - 1);
    return {(x[0] + x[1]) % 997, (x[0] + 2 * x[1]) % 997};
}

int main() {
    std::cout << fib(500)[0];
}

and in Python

def fib(n):
    if n==1:
        return (1, 2)
    x=fib(n-1)
    return ((x[0]+x[1]) % 997, (x[0]+2*x[1]) % 997)

if __name__=='__main__':
    print(fib(500)[0])

Both find the answer 996 without issues. We reduce every value modulo 997 to keep the numbers small, and we return pairs so that each call recurses only once, which avoids exponential branching.

For n=5000, the C++ code outputs 783, but Python will complain

RecursionError: maximum recursion depth exceeded in comparison

If we add a couple of lines

import sys

def fib(n):
    if n==1:
        return (1, 2)
    x=fib(n-1)
    return ((x[0]+x[1]) % 997, (x[0]+2*x[1]) % 997)

if __name__=='__main__':
    sys.setrecursionlimit(5100)  # slightly above 5000: frames already on the stack count too
    print(fib(5000)[0])

then Python too will give the right answer.

For n=50000 C++ finds the answer 151 within milliseconds while Python crashes (at least on my machine).

Why are recursive calls so much cheaper in C++? Can we somehow modify the Python interpreter to make it more tolerant of deep recursion?

Of course, one solution is to replace recursion with iteration. For Fibonacci numbers, this is easy to do. However, iterating bottom-up swaps the roles of the initial and terminal conditions, and the terminal condition is hard to pin down in advance for many problems (e.g. alpha–beta pruning). So in general this requires a lot of hard work on the part of the programmer.
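
For reference, a minimal iterative sketch of the same pair recurrence could look like this. The name `fib_iter` and the `mod` parameter are introduced here for illustration only; the loop starts from the same base pair `(1, 2)` that the recursive Python version returns for `n == 1`.

```python
def fib_iter(n, mod=997):
    """Iterative counterpart of the recursive fib: same pairs, constant stack depth."""
    a, b = 1, 2  # the pair the recursive version returns for n == 1
    for _ in range(n - 1):
        # Apply the same update the recursive version applies on the way back up.
        a, b = (a + b) % mod, (a + 2 * b) % mod
    return (a, b)

if __name__ == '__main__':
    print(fib_iter(500)[0])  # 996, matching the recursive version
```

Because nothing is left on the call stack between steps, the same function also handles n=50000 without touching the recursion limit.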


Solution

  • One solution is a trampoline: instead of calling the next function directly, the recursive function returns a function that will make that call with the appropriate arguments. A loop one level up keeps calling the returned functions until it gets the final result instead of another function. This is only a brief description; there are more thorough explanations online.

    The point is that this converts recursion to iteration. I don't think this is faster, and it may even be slower, but the recursion depth stays low because each call returns before the next one starts.

    An implementation could look like the one below. I split the pair x into a and b for clarity. I then converted the recursive function to a version that carries a and b along as arguments, which makes it tail recursive.

    def fib_acc(n, a, b):
        if n == 1:
            return (a, b)
        return lambda: fib_acc(n - 1, (a+b) % 997, (a+2*b) % 997)
    
    def fib(n):
        x = fib_acc(n, 1, 2)
        while callable(x):
            x = x()
        return x
    
    if __name__=='__main__':
        print(fib(50000)[0])
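
One caveat with the `callable(x)` check: it cannot tell a pending call apart from a final result that happens to be callable itself. A slightly more defensive sketch, assuming we are free to wrap pending calls in a marker class, is shown below. The names `Thunk` and `trampoline` are made up for this example and do not come from any library.

```python
class Thunk:
    """Marks a deferred call, so real results are never mistaken for pending work."""
    def __init__(self, func, *args):
        self.func = func
        self.args = args

def trampoline(func, *args):
    """Call func, then keep executing returned Thunks until a plain value appears."""
    result = func(*args)
    while isinstance(result, Thunk):
        result = result.func(*result.args)
    return result

def fib_acc(n, a, b):
    if n == 1:
        return (a, b)
    # Describe the next call instead of making it; the trampoline loop runs it.
    return Thunk(fib_acc, n - 1, (a + b) % 997, (a + 2 * b) % 997)

def fib(n):
    return trampoline(fib_acc, n, 1, 2)

if __name__ == '__main__':
    print(fib(50000)[0])
```

The stack depth stays constant here as well, since every `fib_acc` call returns before the next one is started by the loop in `trampoline`.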