python, caching, memoization

Memoization/Caching with default optional arguments


I'd like to write a Python decorator that memoizes functions. For example, given

@memoization_decorator
def add(a, b, negative=False):
    print("Computing")
    return (a + b) * (1 if negative is False else -1)

print(add(1, 2))
print(add(1, b=2))
print(add(1, 2, negative=False))
print(add(1, b=2, negative=False))
print(add(a=1, b=2, negative=False))
print(add(a=1, b=2))

I'd like the output to be

Computing
3
3
3
3
3
3

and the output should be the same under any permutation of the last 6 lines.

This amounts to finding a map that sends equivalent combinations of *args and **kwargs to a unique key for the memoization cache dict. The above example has *args, **kwargs equal to

(1, 2), {}
(1,), {'b': 2}
(1, 2), {'negative': False}
(1,), {'b': 2, 'negative': False}
(), {'a': 1, 'b': 2, 'negative': False}
(), {'a': 1, 'b': 2}

Solution

  • For the memoization you can use functools.lru_cache().

    Edit: The problem for your use case is that lru_cache() treats two calls as different if they specify their arguments differently, e.g. add(1, 2) and add(1, b=2) get separate cache entries. To address this we can write our own decorator that sits on top of lru_cache() and transforms the arguments into a single canonical form:

    from functools import lru_cache, wraps
    import inspect
    
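    def canonicalize_args(f):
        """Wrapper for functools.lru_cache() to canonicalize default
        and keyword arguments so cache hits are maximized."""

        # lru_cache() exposes the undecorated function as __wrapped__;
        # inspect it once instead of on every call
        sig = inspect.getfullargspec(f.__wrapped__)
        # defaults is None when the function has no default arguments
        defaults = sig.defaults or ()

        @wraps(f)
        def wrapper(*args, **kwargs):
            # build newargs by filling in defaults, args, kwargs
            newargs = [None] * len(sig.args)
            if defaults:
                newargs[-len(defaults):] = defaults
            newargs[:len(args)] = args
            for name, value in kwargs.items():
                newargs[sig.args.index(name)] = value

            # every call reaches the cache in the same positional form
            return f(*newargs)

        return wrapper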
    
    @canonicalize_args
    @lru_cache()
    def add(a, b, negative=False):
        print("Computing")
        return (a + b) * (1 if negative is False else -1)
    

    Now the body of add() runs only once for the entire set of calls in the question, because every call reaches lru_cache() with all three arguments specified positionally.
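
As a possible variation on the decorator above, the canonicalization step can also be sketched with inspect.signature(), whose bind() and apply_defaults() methods do the argument matching. This is only a sketch under assumptions: the name canonicalize_with_signature is hypothetical, and it still relies on lru_cache() underneath, so all arguments must be hashable.

```python
from functools import lru_cache, wraps
import inspect


def canonicalize_with_signature(f):
    """Hypothetical variant of canonicalize_args based on inspect.signature()."""
    # signature() follows __wrapped__ by default, so it sees the
    # parameters of the function underneath lru_cache().
    sig = inspect.signature(f)

    @wraps(f)
    def wrapper(*args, **kwargs):
        # bind() matches the call to the parameters and raises TypeError
        # for an invalid call, just as calling the function would.
        bound = sig.bind(*args, **kwargs)
        # Fill in any parameters that were left at their defaults.
        bound.apply_defaults()
        # Positional-or-keyword parameters come back positionally in
        # declaration order, so equivalent calls share one cache key.
        return f(*bound.args, **bound.kwargs)

    return wrapper


@canonicalize_with_signature
@lru_cache()
def add(a, b, negative=False):
    print("Computing")
    return (a + b) * (1 if negative is False else -1)


results = [
    add(1, 2),
    add(1, b=2),
    add(1, 2, negative=False),
    add(1, b=2, negative=False),
    add(a=1, b=2, negative=False),
    add(a=1, b=2),
]
print(results)

# wraps() stores the lru_cache object as __wrapped__, so its statistics
# stay reachable: the six calls produce one miss and five hits.
print(add.__wrapped__.cache_info())
```

Compared with the index-based version, this one rejects calls with missing or unknown arguments before they reach the cache, rather than silently passing None for an omitted parameter.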