While decorators accept arguments, those arguments are evaluated once, when the interpreter executes the definition of the decorated function, and thus remain "constant" across all later calls.
The problem is how to change the decorator's argument values at runtime when the decorator comes from a library or other third-party code. My specific problem is deciding when to allow parallel runs with numba's njit decorator, which accepts this as a boolean argument.
This question has been asked a lot, but the answers implicitly assume that we control the decorator's source code. In this question, @C2H5OH's answer actually solves the problem, but the fact that it also works with "external" decorators goes virtually unnoticed, so anyone facing the same problem has to search quite a lot.
So this question is self-answered, to pinpoint this difference and to show how it can be done with numba specifically.
Let's assume we have to work with the following decorator, which prints the value of its argument before the function call, and that we cannot modify its source:
def decorator_with_argument(a):
    def actual_decorator(func):
        def function_wrapper(*args, **kwargs):
            print(f"Before function call. a={a}")
            return func(*args, **kwargs)  # pass the wrapped function's result through
        return function_wrapper
    return actual_decorator
I would like to do something like this:
a = 1

@decorator_with_argument(a)
def foo():
    print("Hi")

>>> foo()
# Before function call. a=1
# Hi
>>> a = 2
>>> foo()
# Before function call. a=1 <---- I would like it to print 2
# Hi
A small numba example is the following:
import numba as nb

parallel = True

@nb.njit(parallel=parallel)
def parallel_test(A):
    s = 0
    for i in nb.prange(A.shape[0]):
        s += A[i]
    return s
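By changing the value of parallel, I would like to allow or stop parallel execution. However, just as with a above, reassigning parallel after parallel_test has been decorated has no effect.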
The workaround is to put the declaration of the decorated function inside (yet) another wrapper function that handles the "setup". The decorator's argument is passed as an argument to this outer "setup" function, which re-declares the same function (with different decorator behavior) each time it is called.
Now there is a scope issue to solve, since the original function is now nested. For that, we can simply return the decorated function or assign it to a global variable.
The code solution:
def function_setup(a):
    @decorator_with_argument(a)
    def _foo():
        print("Hi")
    return _foo  # hand the freshly decorated function back to the caller

>>> foo = function_setup(1)
>>> foo()
# Before function call. a=1
# Hi
>>> foo = function_setup(2)
>>> foo()
# Before function call. a=2
# Hi
For the numba case:
def function_setup(parallel):
    @nb.njit(parallel=parallel)
    def _parallel_test(A):
        s = 0
        for i in nb.prange(A.shape[0]):
            s += A[i]
        return s
    return _parallel_test
import numpy as np
from tqdm import tqdm

A = np.random.random(100)

parallel_test = function_setup(parallel=True)
for _ in tqdm(range(1000000)):
    parallel_test(A)
# 100%|██████████| 1000000/1000000 [00:04<00:00, 230644.00it/s]

parallel_test = function_setup(parallel=False)
for _ in tqdm(range(1000000)):
    parallel_test(A)
# 100%|██████████| 1000000/1000000 [00:00<00:00, 1741571.80it/s]
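Each call to function_setup re-declares and re-decorates the function. With numba, that also produces a new jitted function that has to compile again on its first call. If you switch back and forth between a few argument values, one option is to memoize the setup function with functools.lru_cache so that each variant is built only once. This is an assumption on my part and not part of the approach above. The sketch below uses the plain decorator from the question so that it runs without numba. The argument must be hashable, which a bool such as parallel is.

```python
from functools import lru_cache

def decorator_with_argument(a):
    def actual_decorator(func):
        def function_wrapper(*args, **kwargs):
            print(f"Before function call. a={a}")
            return func(*args, **kwargs)
        return function_wrapper
    return actual_decorator

@lru_cache(maxsize=None)
def function_setup(a):
    # The body runs only once per distinct value of a; later calls return the cached function.
    @decorator_with_argument(a)
    def _foo():
        print("Hi")
        return "Hi"
    return _foo

foo = function_setup(1)
foo()
foo = function_setup(2)
foo()
foo = function_setup(1)  # the same object as the first call, not rebuilt
```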
Presumably, one would want to enable or disable parallel execution for all njit functions together, so all of them would have to be included in the function_setup wrapper and assigned to module-level variables.