
How to provide *args and **kwargs to __call__ using the abbreviated (@) notation of class-based decorators?


I have a question about class-based decorators in Python and how arguments can be passed when decorating a given function.

Consider the following prototype:

class MyDecoratorClass:
    def __init__(self, *args, **kwargs):
        # . . . Any code that is needed here
        pass

    def __call__(self, fun_to_decorate, *args, **kwargs):
        def inner(*args, **kwargs):
            # . . . Any code that is needed here
            return fun_to_decorate(*args, **kwargs)
        return inner

And consider the function to be decorated:

def myfun(*args, **kwargs):
    print("myfun was called")

One way of decorating the above-mentioned function is the following:

  • Step 1: Create an instance of the decorator class. This calls the constructor, that is, __init__. If I pass any arguments here, they go into __init__ only once, when the instance is created.

    decorator = MyDecoratorClass()

  • Step 2: Specify which function is to be decorated and, if required, provide an argument list that corresponds to *args and/or **kwargs in the __call__ signature as defined in the class.

    decorated_fun = decorator(myfun, arg1, arg2)

  • Step 3: Finally, call the decorated function and, if required, provide an argument list for the function itself. If provided, it corresponds to the formal parameters of the inner function defined inside __call__. If I understand correctly, the formal parameters of inner live in a namespace separate from the formal parameters of __call__, so inner's *args and **kwargs shadow those of __call__.

    decorated_fun(...)
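A minimal sketch of the three steps above, assuming the question's prototype instrumented to record which arguments reach __init__ and __call__. The init_args and call_args attributes are hypothetical additions, myfun is changed to return its arguments so they can be inspected, and the strings "a1" and "a2" stand in for arg1 and arg2. It also shows that inner's own *args shadow the extra arguments of __call__, so inner cannot see them under that name.

```python
class MyDecoratorClass:
    """The question's prototype, instrumented to record where arguments land."""

    def __init__(self, *args, **kwargs):
        # Step 1: runs once, when the decorator instance is created.
        self.init_args = args  # hypothetical attribute, for inspection only

    def __call__(self, fun_to_decorate, *args, **kwargs):
        # Step 2: runs once per decorated function; "a1", "a2" arrive here.
        self.call_args = args  # hypothetical attribute, for inspection only

        def inner(*args, **kwargs):
            # Step 3: runs on every call of the decorated function.
            # These *args shadow the *args of __call__ above.
            return fun_to_decorate(*args, **kwargs)

        return inner


def myfun(*args, **kwargs):
    print("myfun was called")
    return args  # returned only so the sketch can inspect it


decorator = MyDecoratorClass()                # Step 1: __init__ receives nothing
decorated_fun = decorator(myfun, "a1", "a2")  # Step 2: __call__ receives "a1", "a2"
result = decorated_fun(1, 2)                  # Step 3: inner receives 1, 2
```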

This works. However, if I want to use the abbreviated notation, I have to write something like this:

@MyDecoratorClass(arg1, arg2)
def myfun(*args, **kwargs):
    print("myfun was called")

In the example above I passed the very same two positional arguments, arg1 and arg2. The problem is that these two arguments are not received by the *args, **kwargs of __call__ but by those of __init__.
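To see this concretely, here is a sketch (again with hypothetical init_args and call_args attributes and placeholder strings for arg1 and arg2) that binds the decorator instance to a name before applying it with @. This should be equivalent to writing @MyDecoratorClass("a1", "a2") directly, because the expression after @ is evaluated first and the resulting object is then called with the function.

```python
class MyDecoratorClass:
    """The question's prototype, instrumented to show which method receives what."""

    def __init__(self, *args, **kwargs):
        # With @MyDecoratorClass("a1", "a2"), the arguments arrive here.
        self.init_args = args  # hypothetical attribute
        self.call_args = None  # hypothetical attribute, set by __call__

    def __call__(self, fun_to_decorate, *args, **kwargs):
        # The @ syntax calls the instance with the function only,
        # so these extra *args are always empty.
        self.call_args = args

        def inner(*args, **kwargs):
            return fun_to_decorate(*args, **kwargs)

        return inner


# Evaluate the decorator expression first, as @MyDecoratorClass("a1", "a2") does.
deco = MyDecoratorClass("a1", "a2")


@deco
def myfun(*args, **kwargs):
    print("myfun was called")
```

The arguments end up in __init__, while __call__ receives only the function, so its extra *args stay empty.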

So my question is: when using the abbreviated notation to decorate a function with a class-based decorator that takes arguments, how can I access the argument list of __call__ in order to provide those arguments for each decorated function?


Solution

  • About the decorator signature

    Doing this...

    decorated_fun = decorator(myfun, arg1, arg2)
    

    ... is not how decorators work. A decorator takes a single argument, the function being decorated, and returns a function (more generally, any callable).

    In particular, the above is not equivalent to the correct way of doing it, shown below, where decorator stands for a decorator factory such as MyDecoratorClass:

    @decorator(arg1, arg2)
    def myfun(*args, **kwargs):
        ...
    
    # Which is actually equivalent to...
    myfun = decorator(arg1, arg2)(myfun)
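The equivalence can be checked with a small sketch built on a hypothetical decorator factory (the name decorator and the decorator_args attribute are illustrative only): decorator("a1", "a2") returns the actual one-argument decorator, which is then applied to the function.

```python
def decorator(*dargs):
    """Hypothetical decorator factory: returns the real one-argument decorator."""

    def actual_decorator(func):
        # Attach the factory arguments so they can be inspected afterwards.
        func.decorator_args = dargs
        return func

    return actual_decorator


@decorator("a1", "a2")
def with_syntax():
    return "called"


def without_syntax():
    return "called"


# The explicit two-step form of the same decoration.
without_syntax = decorator("a1", "a2")(without_syntax)
```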
    

    Accessing __init__ arguments in __call__

    If you want access to the arguments passed to __init__, you can store them as instance attributes.

    class MyDecoratorClass:
        def __init__(self, *args, **kwargs):
            # We store the arguments
            self.args = args
            self.kwargs = kwargs
    
            # Any other code that is needed here
    
    
        def __call__(self, fun_to_decorate): # Single argument
            def inner(*args, **kwargs):
    
                # Here you have access to self.args and self.kwargs
    
                return fun_to_decorate(*args, **kwargs)
    
            return inner
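A usage sketch of the class above. What inner does here, prepending the stored positional arguments and merging the stored keyword arguments into each call, is only a hypothetical example of putting self.args and self.kwargs to use. functools.wraps is an optional addition that copies the wrapped function's name and docstring onto inner.

```python
import functools


class MyDecoratorClass:
    def __init__(self, *args, **kwargs):
        # Decoration-time arguments, stored for later use in inner().
        self.args = args
        self.kwargs = kwargs

    def __call__(self, fun_to_decorate):  # single argument: the function
        @functools.wraps(fun_to_decorate)
        def inner(*args, **kwargs):
            # Decoration-time (self.args, self.kwargs) and call-time
            # (args, kwargs) arguments are both available here.
            # Hypothetical behaviour: prepend stored args, merge kwargs
            # (call-time keywords win on conflicts).
            return fun_to_decorate(*self.args, *args, **{**self.kwargs, **kwargs})

        return inner


@MyDecoratorClass("a1", "a2", sep="-")
def myfun(*args, **kwargs):
    return args, kwargs
```

For example, myfun(1, 2) receives ("a1", "a2", 1, 2) and sep="-", while myfun(3, sep="+") overrides the stored keyword.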
    

    Doing the same with a function

    Keep in mind, though, that you do not need a class-based decorator to accomplish this; the following function-based decorator has the same behaviour.

    def my_decorator_function(*args, **kwargs):
    
        def wrapper(decorated_func):
    
            def inner_wrapper(*inner_args, **inner_kwargs):
                # args and kwargs can be accessed here
                return decorated_func(*inner_args, **inner_kwargs)
    
            return inner_wrapper
    
        return wrapper
    
    @my_decorator_function(arg1, arg2)
    def myfun(*args, **kwargs):
        print("myfun was called")
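A runnable sketch of the function-based version with concrete values in place of arg1 and arg2. The calls list is a hypothetical recorder added only to show that the decoration-time arguments (captured by the closure) and the call-time arguments are both visible inside inner_wrapper; functools.wraps is again optional.

```python
import functools

calls = []  # hypothetical recorder, for illustration only


def my_decorator_function(*args, **kwargs):
    def wrapper(decorated_func):
        @functools.wraps(decorated_func)
        def inner_wrapper(*inner_args, **inner_kwargs):
            # args/kwargs: captured from the @ line via the closure.
            # inner_args/inner_kwargs: supplied on each call.
            calls.append((args, kwargs, inner_args, inner_kwargs))
            return decorated_func(*inner_args, **inner_kwargs)

        return inner_wrapper

    return wrapper


@my_decorator_function("a1", "a2", level=1)
def myfun(*args, **kwargs):
    print("myfun was called")
    return args  # returned only so the sketch can inspect it


result = myfun(10, 20)
```

Here the closure over args and kwargs plays the role that self.args and self.kwargs play in the class version.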