How does the loop work in a decorator (memoization)?


Memoization is a powerful tool. I'm trying to understand the fundamental mechanism, but it doesn't seem to work the way I thought. Could anyone explain in detail how it works in the following code?

def memoize(f):
    memo = {}
    def helper(x):
        if x not in memo:            
            memo[x] = f(x)
        print(memo)
        return memo[x]

    return helper

@memoize
def fib(n):
    if n < 2:
        return n
    else:
        return fib(n-1) + fib(n-2)

What really confuses me is when the decorator memoize comes into play in this example. According to a tutorial, it seems that the whole decorated function runs inside the decorator. Here the decorated function is fib(n). If so, how is the loop in fib(n) handled in the decorator memoize(f)?

Let's take fib(4) as an example to demystify the process:

In [1]: fib(4)
{1: 1}
{1: 1, 0: 0}
{1: 1, 0: 0, 2: 1}
{1: 1, 0: 0, 2: 1}
{1: 1, 0: 0, 2: 1, 3: 2}
{1: 1, 0: 0, 2: 1, 3: 2}
{1: 1, 0: 0, 2: 1, 3: 2, 4: 3}

Why is the first value printed by memoize(f) {1: 1}? I expected memoize(f) to store memo = {4: f(4)} at the very beginning, even though the value of f(4) was not yet known at that moment. I know I was wrong. Could anyone explain how we get this output and how the loop in fib(n) works in memoize(f)?

Thanks a lot.


Solution

  • The memo cache doesn't get populated until the function call returns:

    memo[x] = f(x)

    Since fib is recursive, there are many more calls to f before that first f(4) finishes returning and populates the cache. The first of those calls to actually return is f(1), followed by f(0), etc. (as seen in your print statements).

    If you were to add another print at the start of helper (before you call f) then you'd see the recursive calls as a sandwich, with f(4) starting first but finishing last.

    Here's how you could modify the print statements to show the recursion depth as well:

    def memoize(f):
        memo = {}
        depth = [0]  # one-element list so helper can mutate it (nonlocal would also work)
        def helper(x):
            print(f"{'  '*depth[0]}Calling f({x})...")
            depth[0] += 1
            if x not in memo:            
                memo[x] = f(x)
            print(f"{'  '*depth[0]}Cached: {memo}")
            depth[0] -= 1
            print(f"{'  '*depth[0]}Finished f({x})!")
            return memo[x]
    
        return helper
    
    @memoize
    def fib(n):
        if n < 2:
            return n
        else:
            return fib(n-1) + fib(n-2)
    

    Calling fib(4) now prints:

    Calling f(4)...
      Calling f(3)...
        Calling f(2)...
          Calling f(1)...
            Cached: {1: 1}
          Finished f(1)!
          Calling f(0)...
            Cached: {1: 1, 0: 0}
          Finished f(0)!
          Cached: {1: 1, 0: 0, 2: 1}
        Finished f(2)!
        Calling f(1)...
          Cached: {1: 1, 0: 0, 2: 1}
        Finished f(1)!
        Cached: {1: 1, 0: 0, 2: 1, 3: 2}
      Finished f(3)!
      Calling f(2)...
        Cached: {1: 1, 0: 0, 2: 1, 3: 2}
      Finished f(2)!
      Cached: {1: 1, 0: 0, 2: 1, 3: 2, 4: 3}
    Finished f(4)!
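As for the other half of the question (how the recursion inside fib ends up going through memoize at all): `@memoize` above `def fib` is shorthand for `fib = memoize(fib)`, and the calls `fib(n-1)` and `fib(n-2)` in the body look up the global name `fib` each time they run, so they reach `helper` rather than the original function. Nothing loops inside memoize; each recursive call simply re-enters helper, which checks the cache first. The sketch below tries to make that visible by recording how often the undecorated body runs. The names `fib_body`, `plain_fib`, `cached_plain`, `body_calls` and `plain_calls` are hypothetical, chosen only for illustration, and the print is dropped from helper to keep the output short.

```python
def memoize(f):
    memo = {}
    def helper(x):
        if x not in memo:
            memo[x] = f(x)  # f(x) runs to completion before the key is stored
        return memo[x]
    return helper

body_calls = []   # records each time the undecorated body actually runs

def fib_body(n):
    body_calls.append(n)
    if n < 2:
        return n
    # `fib` is looked up in the module's globals at call time,
    # so these calls go through the memoizing helper bound below.
    return fib(n - 1) + fib(n - 2)

# This rebinding is what `@memoize` on `def fib` does for you.
fib = memoize(fib_body)

# Contrast (hypothetical): wrap a function but bind the wrapper to a different name.
plain_calls = []

def plain_fib(n):
    plain_calls.append(n)
    if n < 2:
        return n
    # These calls refer to `plain_fib` itself, so they never see the cache.
    return plain_fib(n - 1) + plain_fib(n - 2)

cached_plain = memoize(plain_fib)

print(fib(4), body_calls)                  # 3 [4, 3, 2, 1, 0]
print(cached_plain(4), len(plain_calls))   # 3 9
```

With the decorated version, the body runs once per distinct n (five times for fib(4)); when the recursion bypasses the wrapper, only the outermost call is cached and the body runs nine times. In practice, `functools.lru_cache(maxsize=None)`, or `functools.cache` in Python 3.9+, provides the same per-argument caching without a hand-written decorator.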