
Python - how do I memoize a partial object?


I have a set of functions that take integers and functions as arguments. I'd like to memoize them.

I know that, using this solution, I could pickle both kinds of arguments and memoize on the encoded values. In this particular use case, however, the function arguments are large and constant, and I'd rather not fill the lru_cache with a function argument that never changes during a program run.

Is there a way for me to memoize a partial function, where I've fixed the function arguments and have received a partial object that takes only hashable arguments? I can't figure out how to use the functools.lru_cache decorator as a function.

Here's what I've tried on a toy example. It doesn't work; the binomial tree still revisits nodes.

import functools
import logging


logging.basicConfig(level=logging.DEBUG)


def binomial_tree(x, y, fn):
    """Note: this does not recombine, and we can't memoize fn."""
    logging.debug(f"binomial_tree({x}, {y})")
    if x == 10:
        return fn(x, y)
    else:
        return 0.5 * binomial_tree(x + 1, y, fn) + 0.5 * binomial_tree(x + 1, y + 1, fn)


def memoize_fn(fn):
    @functools.lru_cache(maxsize=None)
    def inner(*args, **kwargs):
        return fn(*args, **kwargs)
    return inner

memoized_binomial_tree = memoize_fn(functools.partial(binomial_tree, fn=lambda x, y: 10 * x * y))
print(memoized_binomial_tree(0, 0))
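For reference, `functools.lru_cache` can be called as a function: `lru_cache(maxsize=None)` returns a decorator, and applying it to a `functools.partial` object works. The minimal sketch below (logging dropped; the `call_count` counter and `payoff` name are added purely for illustration) shows why the attempt above still revisits nodes: only the outermost call goes through the cache, while the recursion calls the undecorated module-level `binomial_tree`.

```python
import functools

call_count = 0  # illustration only: counts calls to the undecorated function


def binomial_tree(x, y, fn):
    global call_count
    call_count += 1
    if x == 10:
        return fn(x, y)
    # These recursive calls bypass any cache wrapped around the partial.
    return 0.5 * binomial_tree(x + 1, y, fn) + 0.5 * binomial_tree(x + 1, y + 1, fn)


def payoff(x, y):
    return 10 * x * y


# lru_cache(maxsize=None) returns a decorator, which can wrap a partial object.
cached = functools.lru_cache(maxsize=None)(functools.partial(binomial_tree, fn=payoff))

result = cached(0, 0)
print(result, call_count)  # 500.0 2047: every node of the depth-10 tree is visited
```

Calling `cached(0, 0)` a second time would return from the cache without recursing, but that does not help within a single evaluation of the tree.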

Solution

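  • Here is a way to memoize your toy binomial_tree example without encoding or caching the function argument. The key change is that the recursive calls go through the cached function rather than an undecorated one: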

    import functools
    import logging
    
    
    logging.basicConfig(level=logging.DEBUG)
    
    
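    def create_binomial_tree(fn):
        # fn is captured by the closure, so lru_cache keys only on (x, y).
        @functools.lru_cache(maxsize=None)
        def binomial_tree(x, y):
            logging.debug(f"binomial_tree({x}, {y})")
            if x == 10:
                return fn(x, y)
            else:
                # Here the name binomial_tree refers to the cached wrapper,
                # so the recursive calls hit the cache as well.
                return 0.5 * binomial_tree(x + 1, y) + 0.5 * binomial_tree(x + 1, y + 1)
        return binomial_tree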
    
    
    memoized_binomial_tree = create_binomial_tree(fn=lambda x, y: 10 * x * y)
    print(memoized_binomial_tree(0, 0))
    

    Maybe this approach can be applied to your real use case?