Tags: python, decorator, wrapper

How to write a Python decorator that updates a keyword argument?


The goal is to write a decorator that updates one keyword argument of the wrapped function. In the following code, `wrapper` attempts to update `kwarg1`:

import inspect                                                                                                                                                                                         [0/300678]
from functools import wraps

def override_me(arg, kwarg1="default kwarg1", kwarg2="default kwarg2"):
    """Print the three arguments received, so a decorator's effect on them is visible."""
    message = f"override_me {arg} kwarg1={kwarg1} kwarg2={kwarg2}"
    print(message)

def append_kwarg1(func):
    """Return a wrapper around *func* that appends ``"_patched!"`` to ``kwarg1``.

    The fallback value for ``kwarg1`` is read from *func*'s signature once,
    when the decorator is applied. The wrapper only accepts ``kwarg1`` as a
    keyword. If a caller passes it positionally, that value stays in
    ``*args`` while *func* also receives ``kwarg1=...`` by keyword, so *func*
    gets it twice and raises ``TypeError``. The wrapper returns ``None``
    whatever *func* returns.
    """
    fallback = inspect.signature(func).parameters["kwarg1"].default

    def wrapper(*args, kwarg1=fallback, **kwargs):
        patched = kwarg1 + "_patched!"
        func(*args, kwarg1=patched, **kwargs)

    # Equivalent to decorating ``wrapper`` with ``@wraps(func)``.
    return wraps(func)(wrapper)

# Rebind the name so every call below goes through the decorator.
override_me = append_kwarg1(override_me)


# Prints: override_me passed_arg kwarg1=default kwarg1_patched! kwarg2=default kwarg2
override_me("passed_arg")
# Prints: override_me passed_arg kwarg1=passed_kwarg1_named_patched! kwarg2=default kwarg2
override_me("passed_arg", kwarg1="passed_kwarg1_named")
# wrapper binds kwarg1 only by keyword, so the positional value stays in *args
# while func also receives the patched default as kwarg1=... by keyword.
override_me("passed_arg", "passed_kwarg1_as_arg") # TypeError: override_me() got multiple values for argument 'kwarg1'

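This, however, fails when `kwarg1` is passed as a positional argument. The wrapper's own `kwarg1` parameter keeps its default, so the positional value is forwarded in `*args` while `kwarg1=...` is also passed by keyword, and `override_me` receives `kwarg1` twice.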

Edit: As clarified in the comments, the signature of `override_me` cannot be changed (think: an external module).


Solution

  • A working solution relies on the fact that even a parameter with a default value has a position (index) in the function's signature. Depending on how many positional arguments the caller passes, we can determine whether the argument of interest was passed positionally or by keyword, and then update it in either `args` or `kwargs`:

    
    import inspect                                                                                                                                                                                         
    from functools import wraps
    
    def override_me(arg, kwarg1="default kwarg1", kwarg2="default kwarg2"):
        """Echo the received arguments to stdout so the patched value can be seen."""
        template = "override_me {} kwarg1={} kwarg2={}"
        print(template.format(arg, kwarg1, kwarg2))
    
    def append_kwarg1(func):
        """Wrap *func* so that its ``kwarg1`` argument gets ``"_patched!"`` appended.

        ``kwarg1`` is patched whether the caller passes it positionally, by
        keyword, or omits it. When it is omitted, *func*'s own default is
        patched. *func*'s signature is inspected once, at decoration time.
        The wrapper returns whatever *func* returns.

        Raises:
            KeyError: at decoration time, if *func* has no ``kwarg1`` parameter.
        """
        params = inspect.signature(func).parameters
        kwarg1_param = params["kwarg1"]
        if kwarg1_param.kind is inspect.Parameter.KEYWORD_ONLY:
            # A keyword-only kwarg1 (e.g. after ``*args``) can never arrive
            # positionally. Its index must not be compared with len(args).
            kwarg1_index = None
        else:
            # Position of kwarg1 in the signature. A positional call that
            # supplies more than this many arguments has filled it.
            kwarg1_index = list(params).index("kwarg1")

        def update(v):
            return v + "_patched!"

        @wraps(func)
        def wrapper(*args, **kwargs):
            if kwarg1_index is not None and len(args) > kwarg1_index:
                # kwarg1 was passed positionally: rebuild args with it patched.
                args = (
                    args[:kwarg1_index]
                    + (update(args[kwarg1_index]),)
                    + args[kwarg1_index + 1 :]
                )
            else:
                # kwarg1 was passed by keyword or omitted: patch the given
                # value, or func's default, in kwargs.
                kwargs["kwarg1"] = update(kwargs.get("kwarg1", kwarg1_param.default))
            # Propagate func's result instead of discarding it.
            return func(*args, **kwargs)

        return wrapper
    
    
    # Rebind the name so every call below goes through the decorator.
    override_me = append_kwarg1(override_me)
    
    # Prints: override_me passed_arg kwarg1=passed_kwarg1_named_patched! kwarg2=default kwarg2
    override_me("passed_arg", kwarg1="passed_kwarg1_named")
    # Prints: override_me passed_arg kwarg1=default kwarg1_patched! kwarg2=default kwarg2
    override_me("passed_arg")
    # kwarg1 passed positionally is patched too (no TypeError any more).
    # Prints: override_me passed_arg kwarg1=passed_kwarg1_as_arg_patched! kwarg2=default kwarg2
    override_me("passed_arg", "passed_kwarg1_as_arg")
    # Prints: override_me passed_arg kwarg1=passed_kwarg1_as_arg_patched! kwarg2=passed_kwarg2_as_arg
    override_me("passed_arg", "passed_kwarg1_as_arg", "passed_kwarg2_as_arg")