Search code examples
Tags: python, instance, decorator, python-decorators, self

Function decorator that will return self?


I have the following class, which builds an object with chainable methods that check its instance attributes. Since this code is quite repetitive, my challenge is to write a decorator that can be applied to methods a, b and c. The problem I am facing is that I cannot find a way to write a wrapper that returns the instance (self). Is there a better way to structure this?

class Test:
    """Queue attribute checks through chainable methods and evaluate them later.

    Each of ``a``, ``b`` and ``c`` appends a predicate that compares the
    matching private attribute with the given value, then returns ``self``
    so calls can be chained.  ``evaluate`` runs the queued predicates and
    empties the queue.
    """

    def __init__(self, a, b, c):
        self._a = a
        self._b = b
        self._c = c

        # Predicates queued by a()/b()/c(); consumed and emptied by evaluate().
        self.call_chain = []

    def a(self, truth):
        """Queue the check ``self._a == truth`` and return ``self``."""
        def func():
            return self._a == truth
        self.call_chain.append(func)
        return self

    def b(self, truth):
        """Queue the check ``self._b == truth`` and return ``self``."""
        def func():
            return self._b == truth
        self.call_chain.append(func)
        return self

    def c(self, val):
        """Queue the check ``self._c == val`` and return ``self``."""
        def func():
            return self._c == val
        self.call_chain.append(func)
        return self

    def evaluate(self):
        """Run the queued checks in order, then clear the queue.

        Returns:
            False as soon as a check yields a value equal to False (later
            checks are not run), otherwise True.  An empty queue gives True.

        Raises:
            Whatever a check raises propagates unchanged.  The queue is
            cleared in every case, so stale checks never carry over into
            the next evaluation.
        """
        try:
            for f in self.call_chain:
                # Keep the original ``== False`` test: only results equal to
                # False fail a check. Return directly instead of raising
                # ValueError as control flow, which also hid real ValueErrors
                # coming from a check.
                if f() == False:  # noqa: E712
                    return False
            return True
        finally:
            # Runs on every exit path: True, False, or an exception.
            self.call_chain.clear()

It works chained like this:

# Build an instance, queue three checks, then evaluate them all at once.
c = Test(True, False, 13)
(
    c.a(True)
    .b(False)
    .c(13)
    .evaluate()
)

Solution

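  • The trick is to store the function together with its arguments in the call chain. That way the wrapper can return self immediately, and the real check is deferred until evaluate() runs it. The easiest way to do this is with functools.partial objects.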

    from functools import wraps, partial
    
    def chain(func):
        """Turn *func* into a chainable, deferred method.

        Calling the decorated method does not run *func*.  Instead, a
        ``functools.partial`` that binds ``self`` and the call arguments is
        appended to ``self.call_chain``, and ``self`` is returned so that
        calls can be chained and run later.
        """
        @wraps(func)
        def deferred(self, *call_args, **call_kwargs):
            # Freeze the call now; the caller decides when to execute it.
            self.call_chain.append(partial(func, self, *call_args, **call_kwargs))
            return self

        return deferred
    
    class Test:
        """Chainable attribute checks built with the ``chain`` decorator.

        ``a``, ``b`` and ``c`` compare the matching private attribute with
        ``val``.  Because they are decorated with ``chain``, calling one only
        queues the comparison and returns ``self``.  ``evaluate`` runs the
        queue.
        """

        def __init__(self, a, b, c):
            # Deferred calls queued by the @chain-decorated methods.
            self.call_chain = []
            self._a = a
            self._b = b
            self._c = c

        @chain
        def a(self, val):
            """Check ``self._a == val`` (deferred until ``evaluate``)."""
            return self._a == val

        @chain
        def b(self, val):
            """Check ``self._b == val`` (deferred until ``evaluate``)."""
            return self._b == val

        @chain
        def c(self, val):
            """Check ``self._c == val`` (deferred until ``evaluate``)."""
            return self._c == val

        def evaluate(self):
            """Run the queued calls in order, then clear the queue.

            Returns:
                False as soon as a queued call yields a value equal to False
                (later calls are not run), otherwise True.  An empty queue
                gives True.

            Raises:
                Whatever a queued call raises propagates unchanged.  The queue
                is cleared in every case, so stale calls never carry over into
                the next evaluation.
            """
            try:
                for f in self.call_chain:
                    # Keep the original ``== False`` test. Return directly
                    # instead of raising ValueError as control flow, which
                    # also swallowed real ValueErrors raised by a check.
                    if f() == False:  # noqa: E712
                        return False
                return True
            finally:
                # Runs on every exit path: True, False, or an exception.
                self.call_chain.clear()
    
    c = Test(True, False, 13)
    # evaluate() empties the chain, so the second expression starts fresh.
    (c.a(True)
      .b(False)
      .c(13)
      .evaluate())  # True
    (c.a(True)
      .b(False)
      .c(11)
      .evaluate())  # False