Tags: python, python-3.x, python-decorators, lru

Make built-in lru_cache skip caching when function returns None


Here's a simplified function I'm trying to add an lru_cache to -

from functools import lru_cache

@lru_cache(maxsize=1000)
def validate_token(token):
    if token % 3:
        return None
    return True

for x in range(1000):
    validate_token(x)

print(validate_token.cache_info())

outputs -

CacheInfo(hits=0, misses=1000, maxsize=1000, currsize=1000)

As we can see, it also caches the arguments and return values of the calls that returned None. In the above example, I want the cache size to be 334, i.e. only the calls that returned non-None values. In my real case, the function takes a large number of arguments and may return a different value on a later call even if it previously returned None, so I want to avoid caching None results.
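To make the motivation concrete, here is a minimal sketch of the stale-None problem. It assumes a hypothetical mutable token store, `VALID_TOKENS`, that the function consults; once lru_cache has stored None for a token, later calls keep returning that None even after the token becomes valid.

```python
from functools import lru_cache

# Hypothetical mutable state the function depends on (an assumption for illustration).
VALID_TOKENS = set()

@lru_cache(maxsize=1000)
def validate_token(token):
    # Returns None for unknown tokens, True for known ones.
    if token not in VALID_TOKENS:
        return None
    return True

print(validate_token("abc"))       # None: token not registered yet
VALID_TOKENS.add("abc")
print(validate_token("abc"))       # still None: the first result was cached
print(validate_token.cache_info())  # CacheInfo(hits=1, misses=1, maxsize=1000, currsize=1)
```

Calling `validate_token.cache_clear()` would fix this one case, but it throws away every cached True as well, which is why skipping None at caching time is preferable.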

I want to avoid reinventing the wheel and implementing an lru_cache again from scratch. Is there any good way to do this?

Here are some of my attempts -

1. Implementing my own cache (which is not an LRU cache here) -

from functools import wraps

# global cache object: function key -> {call key: result}
MY_CACHE = {}

def get_func_hash(func):
    # generates a unique key for a function. TODO: what if the function gets redefined?
    return func.__module__ + '|' + func.__name__

def my_lru_cache(func):
    name = get_func_hash(func)
    if name not in MY_CACHE:
        MY_CACHE[name] = {}
    @wraps(func)
    def function_wrapper(*args, **kwargs):
        # include keyword arguments in the key, not just positional ones
        key = (args, tuple(sorted(kwargs.items())))
        if key in MY_CACHE[name]:
            return MY_CACHE[name][key]
        value = func(*args, **kwargs)
        # only cache non-None results
        if value is not None:
            MY_CACHE[name][key] = value
        return value
    return function_wrapper

@my_lru_cache
def validate_token(token):
    if token % 3:
        return None
    return True

for x in range(1000):
    validate_token(x)

print(get_func_hash(validate_token))
print(len(MY_CACHE[get_func_hash(validate_token)]))

outputs -

__main__|validate_token
334

2. I realised that lru_cache does not cache a call when the wrapped function raises an exception, so I used that to keep None results out of the cache -

from functools import wraps, lru_cache

def my_lru_cache(func):
    @wraps(func)
    @lru_cache(maxsize=1000)
    def function_wrapper(*args, **kwargs):
        value = func(*args, **kwargs)
        if value is None:
            # TODO: change this to a custom exception
            raise KeyError
        return value
    return function_wrapper

def handle_exception(func):
    @wraps(func)
    def function_wrapper(*args, **kwargs):
        try:
            value = func(*args, **kwargs)
            return value
        except KeyError:
            return None
    return function_wrapper    

@handle_exception
@my_lru_cache
def validate_token(token):
    if token % 3:
        return None
    return True

for x in range(1000):
    validate_token(x)

print(validate_token.__wrapped__.cache_info())

outputs -

CacheInfo(hits=0, misses=334, maxsize=1000, currsize=334)

The above correctly caches only the 334 values, but it requires wrapping the function twice and accessing the cache info in an awkward way: func.__wrapped__.cache_info().
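The behaviour attempt 2 relies on can be checked directly: a call that raises is not stored, so the next call with the same argument runs the function body again. A small sketch, using a hypothetical call counter to observe how often the body runs:

```python
from functools import lru_cache

calls = {"n": 0}  # hypothetical counter, only here to observe executions

@lru_cache(maxsize=None)
def flaky(x):
    calls["n"] += 1
    if calls["n"] == 1:
        raise ValueError("first call fails")
    return x * 2

try:
    flaky(5)
except ValueError:
    pass

print(flaky(5))    # 10: the failed call was not cached, so the body ran again
print(flaky(5))    # 10: served from the cache this time
print(calls["n"])  # 2: once for the failure, once for the success
```

This sketch deliberately does not rely on the `misses` counter for failed calls, only on how often the body runs and on `hits`/`currsize`.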

Is there a better, Pythonic way to use the built-in lru_cache decorator while skipping the cache when None (or some other specific value) is returned?
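Attempt 1 never evicts entries, and its global MY_CACHE keyed by get_func_hash is what raises the redefinition TODO. If a hand-rolled version were acceptable, a minimal sketch of a bounded LRU that skips None could keep its cache in the decorator's closure and use `collections.OrderedDict` (`move_to_end` and `popitem(last=False)`) to track recency. The decorator name `none_skipping_lru` is hypothetical, and unlike functools.lru_cache this sketch makes no attempt at thread safety.

```python
from collections import OrderedDict
from functools import wraps

def none_skipping_lru(maxsize=128):
    """Hypothetical decorator: a bounded LRU cache that never stores None results."""
    def decorator(func):
        cache = OrderedDict()  # order of keys doubles as recency order (oldest first)

        @wraps(func)
        def wrapper(*args, **kwargs):
            # Keyword arguments are part of the key; all arguments must be hashable.
            key = (args, tuple(sorted(kwargs.items())))
            if key in cache:
                cache.move_to_end(key)  # mark as most recently used
                return cache[key]
            value = func(*args, **kwargs)
            if value is not None:
                cache[key] = value
                if len(cache) > maxsize:
                    cache.popitem(last=False)  # evict the least recently used entry
            return value

        wrapper.cache = cache  # exposed only for inspection
        return wrapper
    return decorator

@none_skipping_lru(maxsize=1000)
def validate_token(token):
    if token % 3:
        return None
    return True

for x in range(1000):
    validate_token(x)

print(len(validate_token.cache))  # 334
```

With maxsize=1000 nothing is evicted here, so the cache holds exactly the 334 non-None results the question expects.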


Solution

  • You are missing the two lines marked here, which copy cache_info and cache_clear onto the outer wrapper:

    def handle_exception(func):
        @wraps(func)
        def function_wrapper(*args, **kwargs):
            try:
                value = func(*args, **kwargs)
                return value
            except KeyError:
                return None
    
        function_wrapper.cache_info = func.cache_info    # Add this
        function_wrapper.cache_clear = func.cache_clear  # Add this
        return function_wrapper
    

    You can also do both wrappers in one decorator:

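    from functools import lru_cache, wraps

    def my_lru_cache(maxsize=128, typed=False):
        # Private exception used to signal "don't cache this result"
        class CustomException(Exception):
            pass

        def decorator(func):
            # lru_cache only stores results of calls that return normally,
            # so raising here keeps None results out of the cache
            @lru_cache(maxsize=maxsize, typed=typed)
            def raise_exception_wrapper(*args, **kwargs):
                value = func(*args, **kwargs)
                if value is None:
                    raise CustomException
                return value

            # Outer wrapper turns the exception back into a None return value
            @wraps(func)
            def handle_exception_wrapper(*args, **kwargs):
                try:
                    return raise_exception_wrapper(*args, **kwargs)
                except CustomException:
                    return None

            # Expose the lru_cache introspection on the outer function
            handle_exception_wrapper.cache_info = raise_exception_wrapper.cache_info
            handle_exception_wrapper.cache_clear = raise_exception_wrapper.cache_clear
            return handle_exception_wrapper

        # Support bare @my_lru_cache usage (no parentheses), like lru_cache itself
        if callable(maxsize):
            user_function, maxsize = maxsize, 128
            return decorator(user_function)

        return decorator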
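The question also mentions skipping "specific" values other than None. A minimal sketch generalising the combined decorator above: `skip` is a hypothetical predicate parameter (functools.lru_cache has no such option), and the private exception carries the uncached result back out, so values other than None are returned unchanged. The LRU bookkeeping is still done by the real lru_cache.

```python
from functools import lru_cache, wraps


def lru_cache_skipping(skip=lambda value: value is None, maxsize=128, typed=False):
    """Hypothetical decorator: like lru_cache, but a result for which
    skip(result) is true is returned without being cached."""

    class _Skip(Exception):
        # Carries the uncached result out through the lru_cache layer.
        def __init__(self, value):
            super().__init__()
            self.value = value

    def decorator(func):
        @lru_cache(maxsize=maxsize, typed=typed)
        def cached(*args, **kwargs):
            value = func(*args, **kwargs)
            if skip(value):
                raise _Skip(value)  # calls that raise are not stored by lru_cache
            return value

        @wraps(func)
        def wrapper(*args, **kwargs):
            try:
                return cached(*args, **kwargs)
            except _Skip as exc:
                return exc.value

        # Expose the usual lru_cache introspection on the outer function.
        wrapper.cache_info = cached.cache_info
        wrapper.cache_clear = cached.cache_clear
        return wrapper

    return decorator


@lru_cache_skipping(maxsize=1000)
def validate_token(token):
    if token % 3:
        return None
    return True


for x in range(1000):
    validate_token(x)

print(validate_token.cache_info().currsize)  # 334
```

For example, `skip=lambda v: not v` would also leave empty or falsy results uncached. Unlike my_lru_cache above, this sketch must always be called with parentheses, because its first positional parameter is the predicate.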