I'm having trouble finding a solution to this problem. Whenever we decorate a method of a class, the method is not yet bound to any instance, so say we have:
```python
from functools import wraps

def decorator(f):
    closure_variable = 0

    @wraps(f)
    def wrapper(*args, **kwargs):
        nonlocal closure_variable
        closure_variable += 1
        print(closure_variable)
        return f(*args, **kwargs)  # pass the decorated method's result through

    return wrapper

class ClassA:
    @decorator
    def decorated_method(self):
        pass
```
This leads to something funny: all instances of `ClassA` share the same closure environment, because `decorator` runs only once, when the class body is executed, and produces a single `wrapper`.
```python
inst1 = ClassA()
inst2 = ClassA()
inst3 = ClassA()
inst1.decorated_method()
inst2.decorated_method()
inst3.decorated_method()
```

The above lines will output:

```
1
2
3
```
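One way to confirm this is to check that every instance resolves `decorated_method` to the same underlying `wrapper` function, and then read the counter out of that function's closure. The sketch below repeats the decorator (without the `print`) so it runs on its own; it uses the standard-library `inspect.getclosurevars` to look at the closure cell.

```python
import inspect
from functools import wraps


def decorator(f):
    closure_variable = 0  # created once, when the decorator is applied

    @wraps(f)
    def wrapper(*args, **kwargs):
        nonlocal closure_variable
        closure_variable += 1
        return f(*args, **kwargs)

    return wrapper


class ClassA:
    @decorator
    def decorated_method(self):
        pass


inst1 = ClassA()
inst2 = ClassA()

# Each attribute access builds a new bound method, but both wrap one function object.
print(inst1.decorated_method.__func__ is inst2.decorated_method.__func__)  # True

inst1.decorated_method()
inst2.decorated_method()

# Both calls incremented the single counter stored in wrapper's closure.
count = inspect.getclosurevars(ClassA.decorated_method).nonlocals["closure_variable"]
print(count)  # 2
```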
Now to my issue at hand: I had created a decorator which caches a token and only requests a new one once it expires. This decorator was applied to a method of a class called `TokenSupplier`.
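The original decorator isn't shown, so the following is only a hypothetical reconstruction of that kind of token cache (the names `cache_token` and `fetch_token`, and the `(token, lifetime)` return shape, are made up for illustration). Because the cached token lives in the decorator's closure, the second instance receives the first instance's token:

```python
import time
from functools import wraps


def cache_token(fetch):
    # Hypothetical reconstruction: these variables live in the closure, so
    # every TokenSupplier instance shares one cached token.
    token = None
    expires_at = 0.0

    @wraps(fetch)
    def wrapper(self):
        nonlocal token, expires_at
        if token is None or time.time() >= expires_at:
            token, lifetime = fetch(self)
            expires_at = time.time() + lifetime
        return token

    return wrapper


class TokenSupplier:
    def __init__(self, name):
        self.name = name

    @cache_token
    def fetch_token(self):
        # Hypothetical token source: returns (token, lifetime in seconds).
        return f"token-for-{self.name}", 60


alice = TokenSupplier("alice")
bob = TokenSupplier("bob")
print(alice.fetch_token())  # token-for-alice
print(bob.fetch_token())    # token-for-alice (bob gets alice's cached token)
```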
When I realized this behavior, it was clear I don't want it: every instance would share the same cached token. Can I solve this issue and keep the decorator design pattern here?
I thought of storing a dictionary in the closure environment and using the instance hash as the key to look up each instance's data, but I believe I might simply be missing something more fundamental.
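The dictionary idea does work. A minimal sketch of one refinement, offered as an assumption rather than the only option, keys the cache by the instance itself in a standard-library `weakref.WeakKeyDictionary`, so an entry disappears when its instance is garbage-collected instead of piling up. The decorator name `per_instance_cache` and the `TokenSupplier` body are hypothetical. Instances must be hashable and weak-referenceable (ordinary class instances are), and keys are compared with `==`, so distinct objects that compare equal would share an entry.

```python
import time
import weakref
from functools import wraps


def per_instance_cache(fetch):
    # One entry per live instance; entries vanish when the instance is collected.
    cache = weakref.WeakKeyDictionary()

    @wraps(fetch)
    def wrapper(self):
        entry = cache.get(self)
        if entry is None or time.time() >= entry[1]:
            token, lifetime = fetch(self)
            entry = (token, time.time() + lifetime)
            cache[self] = entry
        return entry[0]

    return wrapper


class TokenSupplier:
    def __init__(self, name):
        self.name = name
        self.calls = 0

    @per_instance_cache
    def fetch_token(self):
        self.calls += 1  # count real fetches to show the cache is used
        return f"token-for-{self.name}", 60


alice = TokenSupplier("alice")
bob = TokenSupplier("bob")
print(alice.fetch_token(), bob.fetch_token())  # token-for-alice token-for-bob
alice.fetch_token()
print(alice.calls)  # 1: the second call was served from alice's own entry
```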
My goal is for each instance to have its own closure environment while still being able to use a decorator to decorate different future `TokenSupplier` implementations.
Thank you in advance!
To avoid sharing the cache across all instances, which may not be required or desired, it is best to give each instance its own cache holding the token, its expiry time, etc. In other words, there is no need for a single cache shared by all instances.
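In the following implementation, every instance of the class initializes its own cache `dict()` to store the token, its expiration time and any other relevant info, which gives you full control. The decorator is written as a descriptor: its `__get__` method receives the instance the method is accessed on (`inst`), so the wrapper it returns reads and writes that instance's cache.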
```python
from functools import wraps
import time

class TokenCacheDecorator:
    def __init__(self, get_token_func):
        self.get_token_func = get_token_func

    def __get__(self, inst, owner):
        # Accessed on the class itself (ClassA.get_token): return the descriptor.
        if inst is None:
            return self

        # Accessed on an instance: return a wrapper tied to that instance's cache.
        @wraps(self.get_token_func)
        def wrapper(*args, **kwargs):
            if not hasattr(inst, '_token_cache') or inst._token_cache['expiration_time'] < time.time():
                print(f"[{id(inst)}] Cache miss")
                token, expires_in = self.get_token_func(inst, *args, **kwargs)
                inst._token_cache = {
                    'token': token,
                    'expiration_time': time.time() + expires_in
                }
                print(f"[{id(inst)}] New token - {token} expiration time: {inst._token_cache['expiration_time']}")
            return inst._token_cache['token']
        return wrapper

class ClassA:
    def __init__(self, token, expires_in):
        self.token = token
        self.expires_in = expires_in
        # Per-instance cache; expiration_time 0 forces a fetch on the first call.
        self._token_cache = {'token': None, 'expiration_time': 0}

    @TokenCacheDecorator
    def get_token(self):
        return self.token, self.expires_in

inst1 = ClassA("token1", 2)
inst2 = ClassA("token2", 2)
inst3 = ClassA("token3", 2)
print(inst1.get_token())
print(inst2.get_token())
print(inst3.get_token())
time.sleep(3)  # longer than expires_in, so every cached token has expired
print(inst1.get_token())
print(inst2.get_token())
print(inst3.get_token())
```
Output (instance ids and timestamps will differ between runs):

```
[4439687776] Cache miss
[4439687776] New token - token1 expiration time: 1716693215.503801
token1
[4440899024] Cache miss
[4440899024] New token - token2 expiration time: 1716693215.503846
token2
[4440899360] Cache miss
[4440899360] New token - token3 expiration time: 1716693215.503862
token3
[4439687776] Cache miss
[4439687776] New token - token1 expiration time: 1716693218.5076532
token1
[4440899024] Cache miss
[4440899024] New token - token2 expiration time: 1716693218.50767
token2
[4440899360] Cache miss
[4440899360] New token - token3 expiration time: 1716693218.507679
token3
```
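As an alternative sketch (not part of the answer above), a descriptor isn't strictly required: a plain function decorator already receives the instance as its first argument, so it can keep the cache on `self`. The decorator name `token_cache` is hypothetical, and deriving the attribute name from the method name is an assumption made so that several decorated methods on one class don't overwrite each other's cache.

```python
import time
from functools import wraps


def token_cache(fetch):
    # Store the cache on the instance under a name derived from the method,
    # so each instance (and each decorated method) gets its own entry.
    attr = f"_{fetch.__name__}_cache"

    @wraps(fetch)
    def wrapper(self, *args, **kwargs):
        cached = getattr(self, attr, None)
        if cached is None or cached["expiration_time"] < time.time():
            token, expires_in = fetch(self, *args, **kwargs)
            cached = {"token": token, "expiration_time": time.time() + expires_in}
            setattr(self, attr, cached)
        return cached["token"]

    return wrapper


class ClassA:
    def __init__(self, token, expires_in):
        self.token = token
        self.expires_in = expires_in

    @token_cache
    def get_token(self):
        return self.token, self.expires_in


inst1 = ClassA("token1", 60)
inst2 = ClassA("token2", 60)
print(inst1.get_token(), inst2.get_token())  # token1 token2
```

The trade-off is that the cache attribute appears on every instance that calls the method, whereas the descriptor version lets the class pre-initialize `_token_cache` explicitly in `__init__`.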