Tags: python, thread-local, contextmanager

How to put variables on the stack/context in Python


In essence, I want to put a variable on the stack that is reachable by all calls below that point on the stack until the block exits. In Java I would solve this with a static ThreadLocal and some support methods that other methods could then call.
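In Python, the closest match to a Java static `ThreadLocal` is `contextvars.ContextVar` (Python 3.7+); `threading.local` also works if only threads matter. The following is a minimal sketch, not a library API: `scoped_value`, `current_value` and `deep_helper` are hypothetical names. Setting the variable in a `with` block makes the value visible to every call made inside it, and `reset()` restores the previous value when the block exits, so nested blocks behave like a stack. Like a `ThreadLocal`, the value is isolated per thread (and also per asyncio task).

```python
import contextvars
from contextlib import contextmanager

# Hypothetical name: plays the role of Java's static ThreadLocal field.
_current_value = contextvars.ContextVar("current_value")


@contextmanager
def scoped_value(value):
    """Make `value` visible to every call made inside the with block."""
    token = _current_value.set(value)
    try:
        yield value
    finally:
        # Restore whatever was visible before this block (supports nesting).
        _current_value.reset(token)


def current_value():
    """Return the innermost value set by scoped_value(); LookupError if none."""
    return _current_value.get()


def deep_helper():
    # No parameter needed: the value is looked up from the current context.
    return current_value()


with scoped_value("outer"):
    assert deep_helper() == "outer"
    with scoped_value("inner"):
        assert deep_helper() == "inner"
    assert deep_helper() == "outer"
```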

A typical example: you receive a request and open a database connection. Until the request is complete, you want all code to use this connection. Once the request is finished and closed, you close the database connection.
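The request/connection example might look like this with `threading.local`. This is a sketch under a few assumptions: an in-memory SQLite database stands in for the real one, `request_scope`, `get_connection` and `count_rows` are hypothetical names, and nested scopes in the same thread are not supported.

```python
import sqlite3
import threading
from contextlib import contextmanager

# One slot per thread, like a Java static ThreadLocal<Connection>.
_local = threading.local()


@contextmanager
def request_scope(db_path=":memory:"):
    """Open a connection for the duration of one request, then close it."""
    conn = sqlite3.connect(db_path)
    _local.connection = conn
    try:
        yield conn
    finally:
        del _local.connection
        conn.close()


def get_connection():
    """Return the connection of the request running in this thread."""
    conn = getattr(_local, "connection", None)
    if conn is None:
        raise RuntimeError("no request is active in this thread")
    return conn


def count_rows():
    # Deep inside the request handling: no connection argument is passed around.
    cur = get_connection().execute("SELECT 1 UNION ALL SELECT 2")
    return len(cur.fetchall())


with request_scope():
    print(count_rows())  # prints 2
```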

What I need this for is a report generator. Each report consists of multiple parts, each part can rely on different calculations, and sometimes different parts rely partly on the same calculation. As I don't want to repeat heavy calculations, I need to cache them. My idea is to decorate methods with a cache decorator. The cache creates an id based on the method's name, its module and its arguments, checks whether it has already calculated this in a stack variable, and executes the method if not.
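The id-plus-lookup idea can be sketched as follows. `make_cache_id`, `cached_call` and `slow_total` are hypothetical names, the cache dict is passed in explicitly (no stack variable yet), every argument is assumed to be hashable, and `__qualname__` requires Python 3.3+.

```python
def make_cache_id(fn, args, kwargs):
    """Build a hashable id from a function's module, name and arguments.

    Assumes every argument is hashable. Keyword arguments are sorted so the
    order in which they were passed does not change the id.
    """
    return (fn.__module__, fn.__qualname__, args, tuple(sorted(kwargs.items())))


def cached_call(store, fn, *args, **kwargs):
    """Look the result up in `store`; call fn only if it is not there yet."""
    key = make_cache_id(fn, args, kwargs)
    if key not in store:
        store[key] = fn(*args, **kwargs)
    return store[key]


def slow_total(values, scale=1):
    print("computing slow_total")  # shows when the real work happens
    return sum(values) * scale


results = {}
print(cached_call(results, slow_total, (1, 2, 3), scale=2))  # prints "computing slow_total", then 12
print(cached_call(results, slow_total, (1, 2, 3), scale=2))  # prints only 12 (cache hit)
```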

I will try to clarify by showing my current implementation. What I want to do is simplify the code for those implementing calculations.

First, I have the central cache access object, which I call MathContext:

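class MathContext(object):
    def __init__(self, fn):
        self.fn = fn          # file the input data is read from
        self.cache = dict()   # calculation id -> result

    def get(self, calc_config):
        # Compute each distinct calculation at most once per context.
        calc_id = create_id(calc_config)
        if calc_id not in self.cache:
            self.cache[calc_id] = calc_config.exec(self)
        return self.cache[calc_id]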

The fn argument is the name of the file the context is created for; the data for the calculations is read from it.

Then there is the calculation base class:

class CalcBase(object):
    def exec(self, math_context):
        raise NotImplementedError

And here is a deliberately simple Fibonacci example. None of the real calculations are recursive (they work on large sets of data instead), but it demonstrates how one calculation can depend on others:

class Fibonacci(CalcBase):
    def __init__(self, n): self.n = n
    def exec(self, math_context):
        if self.n < 2: return 1
        a = math_context.get(Fibonacci(self.n-1))
        b = math_context.get(Fibonacci(self.n-2))
        return a+b
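The question does not show `create_id`, so the runnable version below uses a hypothetical stand-in built from the class name and the instance attributes; the `misses` counter is likewise added only for demonstration. Note that `exec` is a reserved word in Python 2, so a method named `exec` requires Python 3. Counting misses confirms that each `Fibonacci(n)` is computed only once per context.

```python
class MathContext(object):
    def __init__(self, fn):
        self.fn = fn
        self.cache = {}
        self.misses = 0  # hypothetical counter, only for demonstration

    def get(self, calc_config):
        calc_id = create_id(calc_config)
        if calc_id not in self.cache:
            self.misses += 1
            self.cache[calc_id] = calc_config.exec(self)
        return self.cache[calc_id]


def create_id(calc_config):
    # Hypothetical stand-in: class name plus instance attributes identify a calculation.
    return (type(calc_config).__name__, tuple(sorted(vars(calc_config).items())))


class CalcBase(object):
    def exec(self, math_context):
        raise NotImplementedError


class Fibonacci(CalcBase):
    def __init__(self, n):
        self.n = n

    def exec(self, math_context):
        if self.n < 2:
            return 1
        a = math_context.get(Fibonacci(self.n - 1))
        b = math_context.get(Fibonacci(self.n - 2))
        return a + b


ctx = MathContext("data.csv")  # hypothetical filename; nothing is read here
print(ctx.get(Fibonacci(30)))  # prints 1346269
print(ctx.misses)              # prints 31: Fibonacci(0) .. Fibonacci(30), once each
```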

Instead, I want Fibonacci to be just a decorated function:

@cache
def fib(n):
    if n<2: return 1
    return fib(n-1)+fib(n-2)

With the MathContext approach, when the math_context goes out of scope, so do all its cached values. I want the same for the decorator, i.e., at point X, everything cached by @cache is dereferenced so it can be garbage-collected.


Solution

  • I went ahead and made something that might just do what you want. It can be used as both a decorator and a context manager:

    from __future__ import with_statement
    try:
        import cPickle as pickle
    except ImportError:
        import pickle
    
    
    class cached(object):
        """Decorator/context manager for caching function call results.
        All results are cached in one dictionary that is shared by all cached
        functions.
    
        To use this as a decorator:
            @cached
            def function(...):
                ...
    
        The results returned by a decorated function are not cleared from the
        cache until decorated_function.clear_my_cache() or cached.clear_cache()
        is called.
    
        To use this as a context manager:
    
            with cached(function) as function:
                ...
                function(...)
                ...
    
        The function's return values will be cleared from the cache when the
        with block ends.
    
        To clear all cached results, call the cached.clear_cache() class method.
        """
    
        _CACHE = {}
    
        def __init__(self, fn):
            self._fn = fn
    
        def __call__(self, *args, **kwds):
            key = self._cache_key(*args, **kwds)
            function_cache = self._CACHE.setdefault(self._fn, {})
            try:
                return function_cache[key]
            except KeyError:
                function_cache[key] = result = self._fn(*args, **kwds)
                return result
    
        def clear_my_cache(self):
            """Clear the cache for a decorated function
            """
            try:
                del self._CACHE[self._fn]
            except KeyError:
                pass # no cached results
    
        def __enter__(self):
            return self
    
        def __exit__(self, exc_type, exc_value, traceback):
            self.clear_my_cache()
    
        def _cache_key(self, *args, **kwds):
            """Create a cache key for the given positional and keyword
            arguments. pickle.dumps() is used because there could be
            unhashable objects in the arguments, but pickle.dumps() returns
            a string (bytes in Python 3), which is always hashable.
    
            I used this to make the cached class as generic as possible. Depending
            on your requirements, other key-generating techniques may be more
            efficient.
            """
            return pickle.dumps((args, sorted(kwds.items())), pickle.HIGHEST_PROTOCOL)
    
        @classmethod
        def clear_cache(cls):
            """Clear everything from all functions from the cache
            """
            cls._CACHE = {}
    
    
    if __name__ == '__main__':
        # used as decorator
        @cached
        def fibonacci(n):
            print("calculating fibonacci(%d)" % n)
            if n == 0:
                return 0
            if n == 1:
                return 1
            return fibonacci(n - 1) + fibonacci(n - 2)
    
        for n in range(10):
            print('fibonacci(%d) = %d' % (n, fibonacci(n)))
    
    
        def lucas(n):
            print("calculating lucas(%d)" % n)
            if n == 0:
                return 2
            if n == 1:
                return 1
            return lucas(n - 1) + lucas(n - 2)
    
        # used as context manager
        with cached(lucas) as lucas:
            for i in range(10):
                print('lucas(%d) = %d' % (i, lucas(i)))
    
        # fibonacci's results are still cached, so nothing is recalculated
        for n in range(9, -1, -1):
            print('fibonacci(%d) = %d' % (n, fibonacci(n)))
    
        cached.clear_cache()
    
        # the cache was cleared, so the values are calculated again
        for n in range(9, -1, -1):
            print('fibonacci(%d) = %d' % (n, fibonacci(n)))
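The `cached` class above keeps all results in one class-level dict, so every thread shares the same cache. If per-thread caches are wanted (the Java ThreadLocal behaviour from the question), a minimal variant could store the dicts in a `threading.local`. `thread_cached` is a hypothetical name, and to keep it short it uses the positional arguments themselves as the key (so they must be hashable) and does not accept keyword arguments.

```python
import threading


class thread_cached(object):
    """Decorator/context manager like `cached`, but with one cache per thread."""

    _LOCAL = threading.local()

    def __init__(self, fn):
        self._fn = fn

    def _function_cache(self):
        # Each thread lazily gets its own {function: {args: result}} dict.
        if not hasattr(self._LOCAL, "cache"):
            self._LOCAL.cache = {}
        return self._LOCAL.cache.setdefault(self._fn, {})

    def __call__(self, *args):
        store = self._function_cache()
        if args not in store:
            store[args] = self._fn(*args)
        return store[args]

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Only the current thread's results for this function are dropped.
        getattr(self._LOCAL, "cache", {}).pop(self._fn, None)


calls = []  # records (thread name, argument) for every real computation


@thread_cached
def square(x):
    calls.append((threading.current_thread().name, x))
    return x * x


def work():
    square(3)
    square(3)  # cache hit within the same thread


threads = [threading.Thread(target=work, name="t%d" % i) for i in range(2)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(len(calls))  # prints 2: square(3) is computed once per thread
```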