
Enforcing keyword-only arguments in decorated functions


I have a class with several methods that require a certain argument be present, but for different reasons.

Typically, the argument is attached to the instance as an attribute, in which case it does not need to be passed. However, if the attribute is missing (or None), the argument can be passed instead as a keyword-only argument:

import functools

class Foo:
    def __init__(self, this_kwarg_default=None):
        self.default = this_kwarg_default
    
    @staticmethod
    def require_this_kwarg(reason):
        def enforced(func):
            @functools.wraps(func)
            def wrapped(self, *args, this_kwarg=None, **kwargs):
                if this_kwarg is None:
                    this_kwarg = self.default
                if this_kwarg is None:
                    raise TypeError(f'You need to pass this kwarg, {reason}!')
                return func(self, *args, this_kwarg=this_kwarg, **kwargs)
        
            return wrapped
        return enforced

    require_this_kwarg = require_this_kwarg.__func__

    @require_this_kwarg('because I said so')
    def foo(self, this_kwarg=None):
        print(f'This kwarg is {str(this_kwarg)}')

Mostly, this gives the desired behavior.

>>> myfoo = Foo(42)
>>> myfoo.foo()
This kwarg is 42
>>> myfoo.foo(this_kwarg=4)
This kwarg is 4
>>> yourfoo = Foo()
>>> yourfoo.foo()
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "dec.py", line 15, in wrapped
    raise TypeError(f'You need to pass this kwarg, {reason}!')
TypeError: You need to pass this kwarg, because I said so!

But if any positional argument is passed, I get some unexpected behavior:

>>> myfoo.foo(4)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "dec.py", line 16, in wrapped
    return func(self, *args, this_kwarg=this_kwarg, **kwargs)
TypeError: foo() got multiple values for argument 'this_kwarg'
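
The collision can be reproduced without the class. This is a minimal sketch with hypothetical names (target, forward), assuming the wrapper forwards positional arguments unchanged and also re-passes the keyword, as wrapped does above:

```python
def target(this_kwarg=None):
    return this_kwarg


def forward(*args, this_kwarg=None, **kwargs):
    # The positional 4 lands in args, so the wrapper's own this_kwarg
    # stays None and is replaced by a fallback value (42 is arbitrary here).
    if this_kwarg is None:
        this_kwarg = 42
    # target now receives this_kwarg twice: once positionally via *args
    # and once by keyword.
    return target(*args, this_kwarg=this_kwarg, **kwargs)


try:
    forward(4)
except TypeError as exc:
    message = str(exc)
    print(message)
```

The wrapper never sees the positional value as this_kwarg, because in its own signature this_kwarg comes after *args and is therefore keyword-only.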

It would make sense, then, to define Foo.foo to take this_kwarg as a keyword-only argument:

@require_this_kwarg('because I said so')
def foo(self, *, this_kwarg=None):
    print(f'This kwarg is {str(this_kwarg)}')

However...

>>> myfoo.foo(4)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "dec.py", line 16, in wrapped
    return func(self, *args, this_kwarg=this_kwarg, **kwargs)
TypeError: foo() takes 1 positional argument but 2 positional arguments (and 1 keyword-only argument) were given

In this case, the desired behavior would be to raise TypeError: foo() takes 1 positional argument but 2 were given (the count includes self), just as would be expected if no decorator were used.

My hope was that functools.wraps would enforce the call signature of the decorated function. Obviously, though, this is not what wraps does. Is there some way to achieve this?
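
For reference, here is a small sketch of what functools.wraps does and does not do. The names passthrough and greet are made up for illustration. wraps copies metadata such as __name__ and __doc__ onto the wrapper and sets __wrapped__, which inspect.signature follows, but it does not check calls against the original signature:

```python
import functools
import inspect


def passthrough(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # Anything is accepted here; arguments are only checked
        # when func itself is finally called.
        return func(*args, **kwargs)
    return wrapper


@passthrough
def greet(*, name="world"):
    """Return a greeting."""
    return f"hello {name}"


copied_name = greet.__name__                 # 'greet'
copied_doc = greet.__doc__                   # 'Return a greeting.'
reported = str(inspect.signature(greet))     # "(*, name='world')"

try:
    greet("bob")
except TypeError as exc:
    # Raised by the inner call to the original greet, not by the wrapper.
    error_text = str(exc)
```

So introspection tools report the original signature, while the wrapper's own (*args, **kwargs) signature is what actually governs the call.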


Solution

  • Wow, that was way trickier than I expected. I'd be interested to see if someone comes up with a simpler and cleaner solution, but I think this does what you need?

    from inspect import getfullargspec
    import functools
    
    
    class Foo:
        def __init__(self, x_default):
            self.default = x_default
    
        @staticmethod
        def require_x(reason):
            def enforced(func):
                @functools.wraps(func)
                def wrapped(self, *args, **kwargs):
                    argspec = getfullargspec(func)
                    while True:
                        if 'x' in kwargs:
                            # it's explicitly there, so it will have a value
                            if kwargs['x'] is None:
                                kwargs['x'] = self.default
                            break
                        elif argspec.varargs is None:
                            # there are no varargs to eat up positional arguments
                            if 'x' in argspec.args[:len(args)+1]:
                                # x will get a value from args, offset by one for self
                                if args[argspec.args.index('x') - 1] is None:
                                    args = tuple(a if n != argspec.args.index('x') - 1 else self.default
                                                 for n, a in enumerate(args))
                                break
                            elif argspec.defaults is not None and 'x' in argspec.args[-len(argspec.defaults):]:
                                # x will get a value from a default
                                if argspec.defaults[argspec.args[-len(argspec.defaults):].index('x')] is None:
                                    kwargs['x'] = self.default
                                break
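                        # Checked separately (not as an elif) so a keyword-only x is
                        # still reached when varargs is None; kwonlydefaults may be None.
                        if 'x' in (argspec.kwonlydefaults or {}):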
                            if argspec.kwonlydefaults['x'] is None:
                                kwargs['x'] = self.default
                            break
                        raise TypeError(f'{func.__name__} needs a value for x, {reason}.')
    
                    return func(self, *args, **kwargs)
    
                return wrapped
    
            return enforced
    
        require_x = require_x.__func__
    

    I don't like production code that needs inspect to work, so I still doubt whether you really need code that does this. There is probably a bit of an anti-pattern in the broader design here. But anything can be done, I suppose.
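
One possibly simpler variant, offered as a sketch rather than as the approach above: inspect.signature(func).bind(...) validates a call against the original signature and raises TypeError when it does not fit. The sketch assumes the decorator lives at module level and that the argument is always named this_kwarg. Note that the TypeError raised by bind has its own wording, not the interpreter's "takes N positional arguments" message.

```python
import functools
import inspect


def require_this_kwarg(reason):
    def enforced(func):
        sig = inspect.signature(func)

        @functools.wraps(func)
        def wrapped(self, *args, **kwargs):
            # bind() raises TypeError if the call does not match func's
            # own signature, before any fallback logic runs.
            bound = sig.bind(self, *args, **kwargs)
            if bound.arguments.get('this_kwarg') is None:
                if self.default is None:
                    raise TypeError(f'You need to pass this kwarg, {reason}!')
                bound.arguments['this_kwarg'] = self.default
            return func(*bound.args, **bound.kwargs)

        return wrapped
    return enforced


class Foo:
    def __init__(self, this_kwarg_default=None):
        self.default = this_kwarg_default

    @require_this_kwarg('because I said so')
    def foo(self, *, this_kwarg=None):
        return f'This kwarg is {this_kwarg}'


print(Foo(42).foo())
print(Foo(42).foo(this_kwarg=4))
```

With this version, Foo(42).foo(4) fails in bind() because foo accepts no positional arguments besides self, and Foo().foo() fails with the decorator's own message.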