Tags: python, python-decorators, memoization

Memoization of Imported Functions


I'm creating a decorator to illustrate memoization. Like most examples, I'm using a recursively defined Fibonacci function.

I understand that giving the memoized version a different name from the original makes the memoization ineffective, because the recursive calls still invoke the unmemoized function. (See this old question, Memoization python function.)
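A minimal sketch of that inefficiency, counting how many times the undecorated body runs. The calls counter and the name memo_fib are hypothetical additions for illustration, and the base case is written as n < 2, which matches the original for non-negative n.

```python
# Global counter: how many times the undecorated fib body executes.
calls = 0

def with_memoization(function):
    past_results = {}

    def function_with_memoization(*args):
        if args not in past_results:
            past_results[args] = function(*args)
        return past_results[args]
    return function_with_memoization

def fib(n):
    global calls
    calls += 1
    if n < 2:
        return n
    # These calls look up the global name "fib" at call time.
    return fib(n - 1) + fib(n - 2)

# Memoized under a different name: the recursive calls inside fib still
# resolve "fib" to the undecorated function, so nothing is reused.
memo_fib = with_memoization(fib)
memo_fib(20)
calls_different_name = calls  # 21891 body executions

# Rebinding the same name: the recursive lookups now hit the wrapper.
calls = 0
fib = with_memoization(fib)
fib(20)
calls_same_name = calls  # 21 body executions, one per n in 0..20
```

The difference comes entirely from which object the name fib refers to when the recursive call is executed, not from the decorator itself.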

My issue is that I can't find the correct syntax to overwrite the name of an imported function.

from fibonacci import fibonacci

def with_memoization(function):
    past_results = {}

    def function_with_memoization(*args):
        if args not in past_results:
            past_results[args] = function(*args)
        return past_results[args]
    return function_with_memoization


def fib(n):
    if n == 0:
        return 0
    elif n == 1:
        return 1
    else:
        return fib(n-1) + fib(n-2)


fib = with_memoization(fib)
fibonacci = with_memoization(fibonacci)

print(fib(100)) # completes in <1 second
print(fibonacci(100)) # still not done after more than 2 minutes (exponential time)

The imported fibonacci function and the fib function here are identical. What am I missing?


Solution

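  • The from module import function statement binds the name function in the importing module's namespace to the same function object that lives in the module. When you decorate it and reassign the result, only that local name changes. The recursive calls inside the function look up fibonacci in the global namespace of the module where it was defined, and there the name still refers to the undecorated function.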

    You can think of this as a partial memory: the decorated name remembers the results of its own top-level calls, but not the intermediate steps. In the code above, the result for fibonacci(100) will be the only entry in the dictionary when it's done. (Don't wait up for it.)

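    Using import module syntax doesn't create a second name: fibonacci.fibonacci is an attribute of the module object, and a module's attributes are its global namespace. Assigning the decorated function to fibonacci.fibonacci therefore replaces the very name the recursive calls look up, so they go through the memoized wrapper too.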

    Working implementation:

    import fibonacci
    
    def with_memoization(function):
    
        past_results = {}
    
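        def function_with_memoization(*args, **kwargs):
            # Keyword arguments must be part of the cache key as well,
            # otherwise calls differing only in kwargs would collide.
            key = (args, tuple(sorted(kwargs.items())))
            if key not in past_results:
                past_results[key] = function(*args, **kwargs)
            return past_results[key]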
        return function_with_memoization
    
    
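    # Rebind the name inside the module itself, so the recursive calls
    # made by the module's function also reach the memoized wrapper.
    fibonacci.fibonacci = with_memoization(fibonacci.fibonacci)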
    
    print(fibonacci.fibonacci(100))
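To see both behaviours side by side without a separate fibonacci.py file, the sketch below builds an in-memory stand-in for the module with types.ModuleType and exec. That simulation, the FIB_SOURCE string, the make_module helper and the cache attribute on the wrapper are all assumptions added for illustration; the decorator is otherwise the question's version.

```python
import types

# Source of a hypothetical fibonacci.py module (same body as fib above).
FIB_SOURCE = '''
def fibonacci(n):
    if n == 0:
        return 0
    elif n == 1:
        return 1
    else:
        return fibonacci(n - 1) + fibonacci(n - 2)
'''

def with_memoization(function):
    past_results = {}

    def function_with_memoization(*args):
        if args not in past_results:
            past_results[args] = function(*args)
        return past_results[args]
    # Hypothetical: expose the cache so its contents can be inspected.
    function_with_memoization.cache = past_results
    return function_with_memoization

def make_module():
    """Build a fresh in-memory module, standing in for `import fibonacci`."""
    module = types.ModuleType("fibonacci")
    exec(FIB_SOURCE, module.__dict__)
    return module

# Case 1: like `from fibonacci import fibonacci` plus rebinding the name.
# Only this script's global changes; the module's global is untouched.
module = make_module()
fibonacci = with_memoization(module.fibonacci)
fibonacci(20)
entries_local = len(fibonacci.cache)  # just the top-level (20,) call

# Case 2: like `import fibonacci` and reassigning the module attribute.
# This rebinds the name in the module's global namespace, which is
# where the recursive calls look it up.
module = make_module()
module.fibonacci = with_memoization(module.fibonacci)
result = module.fibonacci(100)
entries_module = len(module.fibonacci.cache)  # one entry per n in 0..100
```

The same attribute assignment works with the standard library's functools.lru_cache(maxsize=None) in place of with_memoization, for the same reason: what matters is that the module's global name is rebound.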