
How to decorate all functions of a class without typing it over and over for each method?


Let's say my class has many methods and I want to apply my decorator to each one of them. When I add new methods later, I want the same decorator applied to them too, but I don't want to write @mydecorator above every method declaration.

Is looking into __call__ the right way to go?
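
For context, here is a minimal sketch showing that `__call__` runs only when the instance itself is called like a function. It does not run when one of the instance's methods is called, so on its own it does not intercept method calls. The class name `Demo` is hypothetical.

```python
class Demo:
    def __init__(self):
        self.log = []  # records which entry point was used

    def __call__(self, *args):
        # Runs only for demo(...), i.e. when the instance itself is called
        self.log.append(("__call__", args))

    def m1(self):
        # Ordinary method: demo.m1() does NOT go through __call__
        self.log.append(("m1",))


demo = Demo()
demo.m1()
demo(1, 2)
print(demo.log)  # [('m1',), ('__call__', (1, 2))]
```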


For anybody finding this question later, here is an approach to a similar problem that uses a mixin, as suggested in the comments:

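class WrapinMixin(object):
    def __call__(self, hey, you, *args):
        # hey: name of the method to call; you, *args: its arguments
        print('entering', hey, you, repr(args))
        ret = None  # make sure 'ret' exists when the finally block runs
        try:
            ret = getattr(self, hey)(you, *args)
            return ret
        except Exception as e:  # bind the exception so str(e) works
            ret = str(e)
            raise
        finally:
            print('leaving', hey, repr(ret))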
    

Then you can use it in another class:

class Wrapmymethodsaround(WrapinMixin):
    def __call__(self, hey, you, *args):
        return super(Wrapmymethodsaround, self).__call__(hey, you, *args)

Editor's note: this example appears to solve a different problem from the one being asked about.
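
The following is a simplified, self-contained variant of the mixin above, assuming Python 3. The `hey`/`you` parameters are collapsed into `name, *args`, and the `Greeter` class is hypothetical. It shows why the mixin addresses a different problem. Logging only happens when a method is invoked by name through the instance, as in `g('greet', 'world')`. A direct `g.greet('world')` bypasses the wrapper entirely.

```python
class WrapinMixin(object):
    def __call__(self, name, *args):
        # Look the method up by name and log around the call
        print('entering', name, repr(args))
        ret = None
        try:
            ret = getattr(self, name)(*args)
            return ret
        except Exception as e:
            ret = str(e)
            raise
        finally:
            print('leaving', name, repr(ret))


class Greeter(WrapinMixin):  # hypothetical example class
    def greet(self, who):
        return 'hello ' + who


g = Greeter()
routed = g('greet', 'world')  # goes through __call__, so entering/leaving are printed
direct = g.greet('world')     # ordinary method call: nothing is printed
```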


Solution

  • Decorate the class with a function that walks through the class's attributes and decorates the callable ones. This may be the wrong thing to do if you have class variables that happen to be callable, and it will also decorate nested classes (credit to Sven Marnach for pointing this out), but generally it's a clean and simple solution. Example implementation (note that it does not exclude special methods such as __init__, which may or may not be desired):

    def for_all_methods(decorator):
        def decorate(cls):
            # Iterate over a snapshot of the class namespace so that
            # setattr() does not modify the mapping being looped over
            for attr, value in list(vars(cls).items()):
                if callable(value):  # also true for nested classes
                    setattr(cls, attr, decorator(value))
            return cls
        return decorate
    

    Use like this:

    @for_all_methods(mydecorator)
    class C(object):
        def m1(self): pass
        def m2(self, x): pass
        ...
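
The caveats above concern callable class variables, nested classes, and special methods. If they matter, one possible variant, sketched here for Python 3, decorates only plain functions defined in the class body and skips dunder methods by default. The names `for_all_plain_methods` and `log_calls`, and the `include_special` parameter, are hypothetical and introduced only for this example.

```python
import functools
import inspect

calls = []  # records the name of every decorated method that runs


def log_calls(func):
    """Hypothetical example decorator that records each call."""
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        calls.append(func.__name__)
        return func(*args, **kwargs)
    return wrapper


def for_all_plain_methods(decorator, include_special=False):
    """Decorate only plain functions defined directly in the class body."""
    def decorate(cls):
        # Snapshot the namespace so setattr() cannot disturb the loop
        for name, value in list(vars(cls).items()):
            # Skips nested classes, callable class variables, and
            # staticmethod/classmethod objects (they are not plain functions)
            if not inspect.isfunction(value):
                continue
            # Optionally skip dunder methods such as __init__ and __repr__
            if not include_special and name.startswith('__') and name.endswith('__'):
                continue
            setattr(cls, name, decorator(value))
        return cls
    return decorate


@for_all_plain_methods(log_calls)
class C(object):
    class Nested(object):  # left undecorated
        pass

    def __init__(self):  # skipped because include_special is False
        self.x = 0

    def m1(self):
        return 'm1'

    def m2(self, x):
        self.x = x
        return x


c = C()
c.m1()
c.m2(5)
print(calls)  # ['m1', 'm2']
```

Because the filter uses `inspect.isfunction` rather than `callable`, `staticmethod` and `classmethod` objects are also left alone. To decorate those as well, you would need to unwrap and rewrap them separately.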