Tags: python, time-complexity, space-complexity

lru_cache vs dynamic programming, stackoverflow with one but not with the other?


I'm doing this basic dp (Dynamic Programming) problem on trees (https://cses.fi/problemset/task/1674/). Given the structure of a company (hierarchy is a tree), the task is to calculate for each employee the number of their subordinates.

This:

import sys
from functools import lru_cache  # noqa
sys.setrecursionlimit(2 * 10 ** 9)

if __name__ == "__main__":
    n: int = 200000
    boss: list[int] = list(range(1, 200001))  
    # so in my example it will be a tree with every parent having one child
    graph: list[list[int]] = [[] for _ in range(n)]

    for i in range(n-1):
        graph[boss[i] - 1].append(i+1)  # directed so neighbours of a node are only its children

    @lru_cache(None)
    def dfs(v: int) -> int:
        if len(graph[v]) == 0:
            return 0
        else:
            s: int = 0
            for u in graph[v]:
                s += dfs(u) + 1
            return s

    dfs(0)

    print(*(dfs(i) for i in range(n)))

crashes (I googled the error message and it means stack overflow):

Process finished with exit code -1073741571 (0xC00000FD)

HOWEVER

import sys
sys.setrecursionlimit(2 * 10 ** 9)

if __name__ == "__main__":
    n: int = 200000
    boss: list[int] = list(range(1, 200001))
    # so in my example it will be a tree with every parent having one child
    graph: list[list[int]] = [[] for _ in range(n)]

    for i in range(n-1):
        graph[boss[i] - 1].append(i+1)  # directed so neighbours of a node are only its children

    dp: list[int] = [0 for _ in range(n)]
    def dfs(v: int) -> None:
        if len(graph[v]) == 0:
            dp[v] = 0
        else:
            for u in graph[v]:
                dfs(u)
                dp[v] += dp[u] + 1

    dfs(0)

    print(*dp)

doesn't, and it's exactly the same complexity, right? The dfs goes exactly as deep in both cases, too. I tried to make the two pieces of code as similar as I could.

I tried 20000000 instead of 200000 (i.e. a graph 100 times deeper) and the second version still doesn't stack-overflow. Obviously I could write an iterative version, but I'm trying to understand the underlying reason why there is such a big difference between these two recursive options, so that I can learn more about Python and its inner workings.

I'm using Python 3.11.1.


Solution

  • lru_cache is implemented in C, so its calls are interleaved with your function's calls, and that deep C-level recursion overflows the C stack and crashes. Your second program only recurses deeply in Python code, not in C code, avoiding the issue.
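
    You can confirm that you get the C implementation (assuming CPython with its C accelerator module, which is the default): the decorated object is an instance of the C-implemented _lru_cache_wrapper type rather than a plain Python function. A quick check:

```python
from functools import lru_cache

@lru_cache(None)
def f(x):
    return x

# On CPython with the C accelerator, the decorated object is an instance of
# the C-implemented functools._lru_cache_wrapper type, not a Python function.
print(type(f).__name__)  # _lru_cache_wrapper
```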

    In Python 3.11 I get a similar bad crash:

    [Execution complete with exit code -11]
    

    In Python 3.12 I just get an error:

    Traceback (most recent call last):
      File "/ATO/code", line 34, in <module>
        dfs(0)
      File "/ATO/code", line 31, in dfs
        s += dfs(u) + 1
             ^^^^^^
      File "/ATO/code", line 31, in dfs
        s += dfs(u) + 1
             ^^^^^^
      File "/ATO/code", line 31, in dfs
        s += dfs(u) + 1
             ^^^^^^
      [Previous line repeated 496 more times]
    RecursionError: maximum recursion depth exceeded
    

    That's despite your sys.setrecursionlimit(2 * 10 ** 9).

    What’s New In Python 3.12 explains:

    sys.setrecursionlimit() and sys.getrecursionlimit(). The recursion limit now applies only to Python code. Builtin functions do not use the recursion limit, but are protected by a different mechanism that prevents recursion from causing a virtual machine crash.

    So in 3.11, your huge limit is applied to C recursion as well, Python obediently attempts your deep recursion, its C stack overflows, and the program crashes. Whereas in 3.12 the limit doesn't apply to C recursion, Python protects itself with that different mechanism at a relatively shallow recursion depth, producing that error instead.
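
    A quick sketch of what the limit does govern in both versions: pure-Python recursion obeys sys.setrecursionlimit and raises a catchable RecursionError when it's exceeded (depths here are chosen small enough to be safe everywhere):

```python
import sys

def depth(n):
    # Plain Python recursion; this is what sys.setrecursionlimit governs.
    return 0 if n == 0 else depth(n - 1) + 1

sys.setrecursionlimit(2000)
try:
    depth(5000)           # too deep for the current limit
except RecursionError:
    print("RecursionError at limit 2000")

sys.setrecursionlimit(10000)
print(depth(5000))        # now fine: prints 5000
```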

    Let's avoid that C recursion. If I use a (simplified) Python version of lru_cache, your first program works fine in both 3.11 and 3.12 without any other change:

    def lru_cache(_):
        # Simplified pure-Python stand-in: ignores maxsize and supports
        # a single positional argument, which is all this program needs.
        def deco(f):
            memo = {}  # cache of already-computed results
            def wrap(x):
                if x not in memo:
                    memo[x] = f(x)
                return memo[x]
            return wrap
        return deco
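
    To see that this avoids the C-stack problem, here's a sketch (assuming CPython 3.11+, where Python-to-Python calls don't consume C stack) that recurses thousands of levels through such a pure-Python wrapper, far deeper than where the C lru_cache hits 3.12's protection:

```python
import sys
sys.setrecursionlimit(20000)

def lru_cache(_):
    # Same simplified pure-Python memoizer as above.
    def deco(f):
        memo = {}
        def wrap(x):
            if x not in memo:
                memo[x] = f(x)
            return memo[x]
        return wrap
    return deco

@lru_cache(None)
def chain(n):
    # Linear recursion through the Python wrapper: two Python frames per level.
    return 0 if n == 0 else chain(n - 1) + 1

print(chain(3000))  # 3000: works, where the C lru_cache errors out in 3.12
```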
    

    See CPython's GitHub issue "3.12 setrecursionlimit is ignored in connection with @functools.cache" for more details and the current progress on this. There are efforts to increase the C recursion limit, but it looks like it'll remain only in the thousands, which is not enough for your super deep recursion.
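
    And if you want to sidestep recursion limits entirely, the problem can be solved with an explicit stack instead of the call stack. A sketch (my own iterative post-order variant of your dp, not one of the original programs):

```python
def subordinates(n: int, graph: list[list[int]]) -> list[int]:
    # Iterative post-order DFS: dp[v] = number of subordinates of v,
    # i.e. the size of v's subtree minus one.
    dp = [0] * n
    stack = [(0, False)]  # (vertex, children_already_processed)
    while stack:
        v, processed = stack.pop()
        if processed:
            dp[v] = sum(dp[u] + 1 for u in graph[v])
        else:
            stack.append((v, True))   # revisit v after its children are done
            for u in graph[v]:
                stack.append((u, False))
    return dp

# Tiny example: 0 is the boss of 1 and 2; 1 is the boss of 3.
print(subordinates(4, [[1, 2], [3], [], []]))  # [3, 1, 0, 0]
```

    This handles the 200000-deep chain (and far deeper) without touching either recursion limit.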

    Miscellaneous further information I might expand on later:

    If you want the Python version of the regular lru_cache, you could try the hack of importing and then deleting the C version of its wrapper before you import the Python version (a technique from this answer, which did the same with heapq):

    import _functools
    del _functools._lru_cache_wrapper
    from functools import lru_cache
    

    Or you could remove the import in functools that replaces the Python implementation with the C implementation, but that's an even worse hack and I wouldn't do it. Maybe I would instead copy and rename the functools module and modify the copy. I mostly mention these hacks to illustrate what's happening, or as a last resort.