Tags: python-3.x, caching, redis, fastapi, decorator

Cache decorator for FastAPI


I have a FastAPI route.

from functools import wraps

@router.get("/query/", status_code=200,)
async def query(q: str):

    # Grab info from database using query string.

    return res

The above is simplified, but it reflects what I generally want to do with various functions in my FastAPI project. I've created the following Python decorator. I believe this is roughly what it should be, but I'm not sure.

def cache(func):

    @wraps(func)
    def wrapper(*args, **kwargs):
        # Cache URL

    return wrapper

I'm using a Redis instance to cache the URLs from my FastAPI routes. However, I can't figure out how to get the URL from the @router decorator and use the q parameter from my function. For example, how can I get the "/query/" part and the value of q: str and use them in my decorator? There might be a way to store the URL in a variable and pass it to each decorator, but I'm looking for the intended way to do this. In other words, how can I do this most concisely without changing all my functions too much?


Solution

  • Explanation

    In order to cache responses for a server, there needs to be a unique key for every response.

    In this case, the __name__ attribute of the route function, combined with the arguments it is called with, should be enough to build that key.
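As a rough illustration of that idea, the sketch below builds a key from a function name and its arguments. The helper name make_cache_key and the "name=value" format for keyword arguments are assumptions made for this example, not part of FastAPI or Redis.

```python
def make_cache_key(func_name, args=(), kwargs=None):
    """Join a function name and its arguments into a single string key."""
    kwargs = kwargs or {}
    parts = [func_name]
    parts += [str(a) for a in args]
    # Sort keyword arguments so the key does not depend on the order they were passed in.
    parts += [f"{name}={value}" for name, value in sorted(kwargs.items())]
    return "-".join(parts)


print(make_cache_key("query", kwargs={"q": "fastapi"}))  # query-q=fastapi
print(make_cache_key("query", kwargs={"q": "redis"}))    # query-q=redis
```

Two calls to the same route with different values of q end up with different keys, which is what keeps their cached responses apart.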

    Configure Redis

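    First and foremost, you need Redis configured. The following example reads the host and port from the environment variables REDIS_SERVER_URL and REDIS_PORT, falling back to localhost and 6379 when they are not set.

    import os

    import redis
    
    # Read connection settings from the environment, with local defaults.
    store = redis.Redis(
        host=os.environ.get("REDIS_SERVER_URL", "localhost"),
        port=int(os.environ.get("REDIS_PORT", "6379")),
    )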
    

    Wrapper Function

    Then, you can create a wrapper function with functools (Python standard library) to use as a decorator.
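To show why functools.wraps matters here, the following minimal sketch wraps a function and then inspects the result. The decorator name passthrough is hypothetical and does no caching. The point is only that the wrapped function keeps its original name (which this answer uses for the cache key) and its original signature (which, as far as I understand, FastAPI inspects to find parameters such as q).

```python
import inspect
from functools import wraps


def passthrough(func):
    """Hypothetical decorator that only forwards the call."""
    @wraps(func)
    def wrapped(*args, **kwargs):
        return func(*args, **kwargs)
    return wrapped


@passthrough
def query(q: str):
    return {"q": q}


print(query.__name__)            # query
print(inspect.signature(query))  # (q: str)
```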

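    The following decorator takes one optional input: hours, the number of hours to keep a cached value. It can be applied either as @cache or as @cache(hours=...).

    import json
    from functools import wraps
    
    def cache(func=None, *, hours=2):
        def wrapper(func):
            @wraps(func)
            def wrapped(*args, **kwargs):
                # FastAPI passes route parameters as keyword arguments,
                # so they must be part of the key alongside any positional ones.
                key_parts = [func.__name__] + list(args)
                key_parts += [f"{k}={v}" for k, v in sorted(kwargs.items())]
                key = "-".join(str(k) for k in key_parts)
                result = store.get(key)
    
                if result is None:
                    # Cache miss: call the function and store its JSON-encoded output.
                    value = func(*args, **kwargs)
                    value_json = json.dumps(value)
                    expire_time = int(60 * 60 * hours)
                    store.setex(key, expire_time, value_json)
                else:
                    # Cache hit: Redis returns bytes, so decode before parsing.
                    value_json = result.decode("utf-8")
                    value = json.loads(value_json)
    
                return value
            return wrapped
    
        # Support both @cache and @cache(hours=...).
        return wrapper if func is None else wrapper(func)
    

    Caching

    Finally, you can use the decorator to cache your routes. Place @cache below @router.get so that FastAPI registers the cached function rather than the original one. The wrapper above is synchronous, so it only suits plain def routes. The following example caches the root function for two hours (default).

    @router.get("/", status_code=200,)
    @cache
    def root():
        return {"message": "Hello World!"}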
    

    Async Caching

    For my specific use case, I also needed caching for coroutine functions. To do that, you need slightly different code to run things asynchronously. In the following wrapper, there is a conditional to handle both async functions and normal functions.

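    import json
    from functools import wraps
    from inspect import iscoroutinefunction
    
    def cache(func=None, *, hours=2):
        def wrapper(func):
            # Check once, at decoration time, whether func must be awaited.
            is_coroutine = iscoroutinefunction(func)
    
            @wraps(func)
            async def wrapped(*args, **kwargs):
                # Include keyword arguments: FastAPI passes route parameters by keyword.
                key_parts = [func.__name__] + list(args)
                key_parts += [f"{k}={v}" for k, v in sorted(kwargs.items())]
                key = "-".join(str(k) for k in key_parts)
                result = store.get(key)
    
                if result is None:
                    if is_coroutine:
                        value = await func(*args, **kwargs)
                    else:
                        value = func(*args, **kwargs)
                    value_json = json.dumps(value)
                    expire_time = int(60 * 60 * hours)
                    store.setex(key, expire_time, value_json)  # Storing the output in Redis
                else:
                    value_json = result.decode("utf-8")
                    value = json.loads(value_json)
    
                return value
            return wrapped
    
        # Support both @cache and @cache(hours=...).
        return wrapper if func is None else wrapper(func)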
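To check the behaviour without a running Redis server, here is a self-contained sketch of the same async caching idea. It rests on a few assumptions: FakeStore is a hypothetical in-memory stand-in that only mimics the get and setex calls used above (it ignores expiry), the decorator is a simplified version without the hours option, and the route is called directly rather than through FastAPI.

```python
import asyncio
import json
from functools import wraps
from inspect import iscoroutinefunction


class FakeStore:
    """Hypothetical in-memory stand-in for the Redis client (expiry is ignored)."""

    def __init__(self):
        self.data = {}

    def get(self, key):
        return self.data.get(key)

    def setex(self, key, seconds, value):
        # Redis hands back bytes, so store bytes to mimic that.
        self.data[key] = value.encode("utf-8")


store = FakeStore()
calls = []  # records every real (uncached) execution


def cache(func):
    is_coroutine = iscoroutinefunction(func)

    @wraps(func)
    async def wrapped(*args, **kwargs):
        parts = [func.__name__, *args]
        parts += [f"{k}={v}" for k, v in sorted(kwargs.items())]
        key = "-".join(str(p) for p in parts)
        result = store.get(key)
        if result is not None:
            return json.loads(result.decode("utf-8"))
        value = await func(*args, **kwargs) if is_coroutine else func(*args, **kwargs)
        store.setex(key, 7200, json.dumps(value))
        return value

    return wrapped


@cache
async def query(q: str):
    calls.append(q)
    return {"q": q}


async def main():
    first = await query(q="fastapi")
    second = await query(q="fastapi")  # same key, served from the store
    other = await query(q="redis")     # different key, runs the function again
    return first, second, other


first, second, other = asyncio.run(main())
print(calls)  # ['fastapi', 'redis']
```

The second call with q="fastapi" never reaches the function body, while q="redis" does, because the value of q is part of the key.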