
Why do 'functools.cache' and 'functools.lru_cache' not work for an inner function inside a class method?


Why are functools.cache and functools.lru_cache not working for an inner function inside a class method? Is there a workaround that doesn't require installing third-party packages or moving the inner function to an outer scope?

from functools import cache

class Sample:

    def outer(self, a):
        @cache
        def inner(b):
            print(f"Inner function is called! b: '{b}'")
            return b

        return inner(a)


sample = Sample()
sample.outer(100)
sample.outer(100)
sample.outer(100)

Output:

Inner function is called! b: '100'
Inner function is called! b: '100'
Inner function is called! b: '100'

I want the inner function to be called only once for the same argument.


Solution

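  • The inner function is "re-created" every time outer is called: the def statement runs again, producing a new function object, and @cache wraps that new object with its own, initially empty cache. This is corroborated by the following example, which calls inner several times from within a single call to outer: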

    from functools import cache
    
    class Sample:
    
        def outer(self, a):
            @cache
            def inner(b):
                print(f"Inner function is called! b: '{b}'")
                return b
    
            for i in range(3):
                print(f"Inner call #{i}")
                print(inner(a))
    
    sample = Sample()
    sample.outer(100)
    sample.outer(100)
    sample.outer(100)
    

    This outputs:

    Inner call #0
    Inner function is called! b: '100'
    100
    Inner call #1
    100
    Inner call #2
    100
    Inner call #0
    Inner function is called! b: '100'
    100
    Inner call #1
    100
    Inner call #2
    100
    Inner call #0
    Inner function is called! b: '100'
    100
    Inner call #1
    100
    Inner call #2
    100
    
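    To check directly that each call to outer produces a separate cached function, one can hand the wrapper back out of outer and inspect it. This is a minimal sketch: returning inner alongside the result, and the names inner1 and inner2, are hypothetical changes made only for this demonstration. cache_info() is the standard statistics method on functions wrapped by functools.cache and functools.lru_cache.

```python
from functools import cache


class Sample:
    def outer(self, a):
        # `def` runs on every call to outer, so a new function object is
        # created each time, and @cache gives that new object its own,
        # initially empty cache.
        @cache
        def inner(b):
            return b

        # Demo-only change: also hand back the cached wrapper so it can be
        # inspected from outside.
        return inner(a), inner


sample = Sample()
_, inner1 = sample.outer(100)
_, inner2 = sample.outer(100)

print(inner1 is inner2)     # False: two different cached wrappers
print(inner1.cache_info())  # CacheInfo(hits=0, misses=1, maxsize=None, currsize=1)
print(inner2.cache_info())  # CacheInfo(hits=0, misses=1, maxsize=None, currsize=1)
```

    Each wrapper recorded exactly one miss and no hits, because the second call to outer never sees the cache filled by the first one.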

    So the caching does work, but only within a single call to outer, until the inner function is "re-created". A very simple solution is to avoid the inner function altogether and use a static method instead:

    from functools import cache
    
    class Sample:
    
        @staticmethod
        @cache
        def _inner(b):
            print(f"Inner function is called! b: '{b}'")
            return b
    
        def outer(self, a):
            for i in range(3):
                print(f"Inner call #{i}")
                print(Sample._inner(a))
    
    sample = Sample()
    sample.outer(100)
    sample.outer(100)
    sample.outer(100)
    

    Note that I've renamed inner to _inner to indicate that _inner is not meant to be used by callers of your class. I've also made it a static method: since _inner doesn't use the self parameter, it doesn't need to be bound to an instance, so you can call it as Sample._inner(...) (self._inner(...) works as well).

    This produces the desired output:

    Inner call #0
    Inner function is called! b: '100'
    100
    Inner call #1
    100
    Inner call #2
    100
    Inner call #0
    100
    Inner call #1
    100
    Inner call #2
    100
    Inner call #0
    100
    Inner call #1
    100
    Inner call #2
    100
    

    This shows that the result is now cached across separate calls to outer.
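    The static-method version moves the function out of outer and shares one cache across all Sample instances. If inner must stay defined inside outer, as the question asks, one possible workaround (a minimal sketch, not the approach from the answer above) is to build the cached inner function only on the first call and store it on the instance, so later calls reuse the same wrapper and therefore the same cache. The attribute name _cached_inner is hypothetical.

```python
from functools import cache


class Sample:
    def outer(self, a):
        # Look up a cached inner function built by an earlier call on this
        # instance. `_cached_inner` is a hypothetical attribute name.
        inner = getattr(self, "_cached_inner", None)
        if inner is None:
            # First call on this instance: define and cache inner exactly once.
            @cache
            def inner(b):
                print(f"Inner function is called! b: '{b}'")
                return b

            self._cached_inner = inner
        return inner(a)


sample = Sample()
sample.outer(100)  # prints "Inner function is called! b: '100'"
sample.outer(100)  # served from the cache, prints nothing
sample.outer(100)  # served from the cache, prints nothing
```

    With this design each Sample instance gets its own cache, which goes away together with the instance, whereas the static-method cache lives as long as the class.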