Memoization is a powerful tool. I'm trying to understand the fundamental mechanism, but it doesn't seem to work the way I thought. Could anyone explain in detail how it works in the following code?
def memoize(f):
    memo = {}
    def helper(x):
        if x not in memo:
            memo[x] = f(x)
        print(memo)
        return memo[x]
    return helper

@memoize
def fib(n):
    if n < 2:
        return n
    else:
        return fib(n-1) + fib(n-2)
What really confuses me is when the decorator memoize comes into play in this example. According to a tutorial, it seems the whole decorated function runs inside the decorator. Here the function is fib(n). If so, how is the loop in fib(n) handled in the decorator memoize(f)?
Let's take fib(4) as an example to demystify the process:
In [1]: fib(4)
{1: 1}
{1: 1, 0: 0}
{1: 1, 0: 0, 2: 1}
{1: 1, 0: 0, 2: 1}
{1: 1, 0: 0, 2: 1, 3: 2}
{1: 1, 0: 0, 2: 1, 3: 2}
{1: 1, 0: 0, 2: 1, 3: 2, 4: 3}
Why is the first value printed in memoize(f) {1: 1}? I expected memoize(f) to store memo = {4: f(4)} at the very beginning, even though the value of f(4) was not yet known at that moment. I know I was wrong. Could anyone explain how we get this output and how the loop in fib(n) works in memoize(f)?
Thanks a lot.
The memo cache doesn't get populated until the function call returns:
memo[x] = f(x)
Since fib is recursive, there are a bunch more calls to f before that first f(4) finishes returning and populates the cache. The first of those calls to actually return is f(1), followed by f(0), etc. (as seen in your print statements).
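It may also help to remember that @memoize is just shorthand for fib = memoize(fib). After that line the name fib refers to helper, so the recursive calls inside the original function body go through helper as well. Here is a minimal sketch of that; the order list is my own addition for illustration (it is not in your code) and records the keys in the order they land in the cache:

```python
def memoize(f):
    memo = {}
    order = []  # hypothetical addition: keys in the order they get cached

    def helper(x):
        if x not in memo:
            # f(x) has to return before the assignment can happen,
            # so deeper recursive calls are cached first.
            memo[x] = f(x)
            order.append(x)
        return memo[x]

    helper.order = order  # expose the list for inspection
    return helper


def fib(n):
    if n < 2:
        return n
    return fib(n - 1) + fib(n - 2)  # looks up the global name fib at call time


fib = memoize(fib)  # this is what @memoize does

result = fib(4)
print(result)        # 3
print(fib.order)     # [1, 0, 2, 3, 4]
print(fib.__name__)  # helper
```

So 4 is the last key to be stored, not the first, because memo[4] can only be assigned once f(4) has a value to return.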
If you were to add another print at the start of helper (before you call f), then you'd see the recursive calls as a sandwich, with f(4) starting first but finishing last.
Here's how you could modify the print statements to show the recursion depth as well:
def memoize(f):
    memo = {}
    depth = [0]  # a one-element list so helper can mutate it without nonlocal
    def helper(x):
        print(f"{' '*depth[0]}Calling f({x})...")
        depth[0] += 1
        if x not in memo:
            memo[x] = f(x)  # recursive calls re-enter helper here
        print(f"{' '*depth[0]}Cached: {memo}")
        depth[0] -= 1
        print(f"{' '*depth[0]}Finished f({x})!")
        return memo[x]
    return helper

@memoize
def fib(n):
    if n < 2:
        return n
    else:
        return fib(n-1) + fib(n-2)
prints:
Calling f(4)...
 Calling f(3)...
  Calling f(2)...
   Calling f(1)...
    Cached: {1: 1}
   Finished f(1)!
   Calling f(0)...
    Cached: {1: 1, 0: 0}
   Finished f(0)!
   Cached: {1: 1, 0: 0, 2: 1}
  Finished f(2)!
  Calling f(1)...
   Cached: {1: 1, 0: 0, 2: 1}
  Finished f(1)!
  Cached: {1: 1, 0: 0, 2: 1, 3: 2}
 Finished f(3)!
 Calling f(2)...
  Cached: {1: 1, 0: 0, 2: 1, 3: 2}
 Finished f(2)!
 Cached: {1: 1, 0: 0, 2: 1, 3: 2, 4: 3}
Finished f(4)!
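Notice that the second f(1) and f(2) calls print the cache without changing it: those are cache hits, where helper returns memo[x] without running the body of fib again. If you want to check that, one way is to count how often the function body actually runs. This is only a sketch; the calls counter and the undecorated fib_plain are names I made up for the comparison:

```python
def memoize(f):
    memo = {}
    def helper(x):
        if x not in memo:
            memo[x] = f(x)
        return memo[x]
    return helper


calls = {"memoized": 0, "plain": 0}  # hypothetical counters, for illustration only


@memoize
def fib(n):
    calls["memoized"] += 1  # only reached on a cache miss
    if n < 2:
        return n
    return fib(n - 1) + fib(n - 2)


def fib_plain(n):
    calls["plain"] += 1  # reached on every call
    if n < 2:
        return n
    return fib_plain(n - 1) + fib_plain(n - 2)


fib(4)
fib_plain(4)
print(calls)  # {'memoized': 5, 'plain': 9}
```

helper is entered 7 times for fib(4) (the 7 lines in your original output), but the body of fib only runs 5 times, once per distinct argument. Without the decorator it runs 9 times, and the gap grows quickly for larger n.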