Search code examples
python magic-methods

Is it possible to "catch" magic methods in Python?


Inspired by this question, I thought it'd be interesting to throw together a "MutableNum" class for fun that, in as many cases as possible, acted just like a standard numeric type, but it would be mutable, so something like the following would work:

def double(x):
    """Double *x* with an augmented assignment; nothing is returned."""
    x *= 2

# Desired behaviour: double() changes the caller's object, not a local copy.
x = MutableNum(9)
print(x)             # 9
double(x)            # intended to modify x in place via `x *= 2`
print(x)             # 18 (the goal; a plain immutable int would still show 9)

I got to the following:

class MutableNum():
    """Minimal mutable wrapper around a number, stored in ``self.val``.

    ``*`` returns a new instance, while ``*=`` changes ``self.val`` in place
    and returns ``self`` so that a function doing ``x *= 2`` on its argument
    modifies the caller's object.  Operands that are themselves ``MutableNum``
    instances are unwrapped so ``val`` always stays a plain number.
    """
    val = None
    def __init__(self, v): self.val = v
    # Comparison Methods
    def __eq__(self, x):    return self.val == x
    def __ne__(self, x):    return self.val != x
    def __lt__(self, x):    return self.val <  x
    def __gt__(self, x):    return self.val >  x
    def __le__(self, x):    return self.val <= x
    def __ge__(self, x):    return self.val >= x
    # Arithmetic (returns a new instance)
    def __mul__(self, x):   return self.__class__(self.val * (x.val if isinstance(x, MutableNum) else x))
    def __rmul__(self, x):  return self.__class__(self.val * x)
    # Compound Assignment (mutates in place; without this `x *= 2` would rebind to a new object)
    def __imul__(self, x):  self.val *= (x.val if isinstance(x, MutableNum) else x); return self
    # Casts
    def __int__(self):      return int(self.val)    # __int__ must return a real int, even for a float val
    # Representation (str()/%r instead of %d so float values are not truncated)
    def __str__(self):      return str(self.val)
    def __repr__(self):     return "%s(%r)" % (self.__class__.__name__, self.val)

This works (so far, as far as I can tell), but I found myself wanting to "catch" the magic methods, since many of them would follow a very similar structure.

For example, I'd like to catch __mul__, __add__, __sub__, etc. in something like:

def catch(self, method, x):
    """Apply the magic method named *method* to ``self.val`` and re-wrap the result.

    ``method`` is the method name as a string (e.g. ``"__add__"``).  It is
    resolved with ``getattr`` because ``self.val.method`` would look for an
    attribute literally called ``method``.

    Returns ``NotImplemented`` unchanged when the wrapped value's method
    reports that it cannot handle *x*; otherwise returns a new instance of
    ``self``'s class holding the result.
    """
    result = getattr(self.val, method)(x)
    if result is NotImplemented:
        # Do not wrap the sentinel as if it were a numeric result.
        return result
    return self.__class__(result)

So for __add__, catch() would return

return MutableNum(self.val.__add__(x))

Is something like this possible? Or should I just implement all of the magic methods like I have already done?

EDIT: I've experimented a little with trying to catch magic methods with __getattr__(self,key), but I'm getting mixed results.

Thanks in advance.

Edit 2

With everyone's help, here's what I came up with:

import operator


class MutableNum(object):
    """A mutable wrapper around a plain number (int or float).

    Unary and binary operators return a new ``MutableNum``.  The augmented
    assignments (``+=``, ``*=``, ``/=``, ``&=``, ...) change the wrapped value
    in place and return ``self``, so a function can modify a number it was
    passed.  Operands that are themselves ``MutableNum`` instances are
    unwrapped first, so ``__val__`` is always a plain number and never a
    nested wrapper.

    Attributes that are not defined here are forwarded to the wrapped value.
    """
    __val__ = None
    # Instances are mutable and compare by value, so they must not be hashable.
    __hash__ = None
    # Types accepted by set(); ``long`` only exists on Python 2.
    try:
        _numeric_types = (int, long, float)
    except NameError:
        _numeric_types = (int, float)

    @staticmethod
    def _raw(x):
        """Return the plain value behind *x*: ``x.__val__`` for a MutableNum, else *x*."""
        return x.__val__ if isinstance(x, MutableNum) else x

    def __init__(self, v): self.__val__ = self._raw(v)
    # Comparison Methods
    def __eq__(self, x):        return self.__val__ == self._raw(x)
    def __ne__(self, x):        return self.__val__ != self._raw(x)
    def __lt__(self, x):        return self.__val__ <  self._raw(x)
    def __gt__(self, x):        return self.__val__ >  self._raw(x)
    def __le__(self, x):        return self.__val__ <= self._raw(x)
    def __ge__(self, x):        return self.__val__ >= self._raw(x)
    def __cmp__(self, x):       # Python 2 only; compares against x (not against 0)
        x = self._raw(x)
        return 0 if self.__val__ == x else 1 if self.__val__ > x else -1
    # Unary Ops
    def __pos__(self):          return self.__class__(+self.__val__)
    def __neg__(self):          return self.__class__(-self.__val__)
    def __abs__(self):          return self.__class__(abs(self.__val__))
    # Bitwise Unary Ops
    def __invert__(self):       return self.__class__(~self.__val__)
    # Arithmetic Binary Ops
    def __add__(self, x):       return self.__class__(self.__val__ + self._raw(x))
    def __sub__(self, x):       return self.__class__(self.__val__ - self._raw(x))
    def __mul__(self, x):       return self.__class__(self.__val__ * self._raw(x))
    def __div__(self, x):       return self.__class__(self.__val__ / self._raw(x))    # Python 2 `/`
    def __mod__(self, x):       return self.__class__(self.__val__ % self._raw(x))
    def __pow__(self, x):       return self.__class__(self.__val__ ** self._raw(x))
    def __floordiv__(self, x):  return self.__class__(self.__val__ // self._raw(x))
    def __divmod__(self, x):    # divmod() yields a (quotient, remainder) pair; wrap each part
        q, r = divmod(self.__val__, self._raw(x))
        return self.__class__(q), self.__class__(r)
    # operator.truediv handles mixed int/float operands, where calling
    # int.__truediv__ directly would return NotImplemented.
    def __truediv__(self, x):   return self.__class__(operator.truediv(self.__val__, self._raw(x)))
    # Reflected Arithmetic Binary Ops
    def __radd__(self, x):      return self.__class__(self._raw(x) + self.__val__)
    def __rsub__(self, x):      return self.__class__(self._raw(x) - self.__val__)
    def __rmul__(self, x):      return self.__class__(self._raw(x) * self.__val__)
    def __rdiv__(self, x):      return self.__class__(self._raw(x) / self.__val__)    # Python 2 `/`
    def __rmod__(self, x):      return self.__class__(self._raw(x) % self.__val__)
    def __rpow__(self, x):      return self.__class__(self._raw(x) ** self.__val__)
    def __rfloordiv__(self, x): return self.__class__(self._raw(x) // self.__val__)
    def __rdivmod__(self, x):
        q, r = divmod(self._raw(x), self.__val__)
        return self.__class__(q), self.__class__(r)
    def __rtruediv__(self, x):  return self.__class__(operator.truediv(self._raw(x), self.__val__))
    # Bitwise Binary Ops
    def __and__(self, x):       return self.__class__(self.__val__ & self._raw(x))
    def __or__(self, x):        return self.__class__(self.__val__ | self._raw(x))
    def __xor__(self, x):       return self.__class__(self.__val__ ^ self._raw(x))
    def __lshift__(self, x):    return self.__class__(self.__val__ << self._raw(x))
    def __rshift__(self, x):    return self.__class__(self.__val__ >> self._raw(x))
    # Reflected Bitwise Binary Ops
    def __rand__(self, x):      return self.__class__(self._raw(x) & self.__val__)
    def __ror__(self, x):       return self.__class__(self._raw(x) | self.__val__)
    def __rxor__(self, x):      return self.__class__(self._raw(x) ^ self.__val__)
    def __rlshift__(self, x):   return self.__class__(self._raw(x) << self.__val__)
    def __rrshift__(self, x):   return self.__class__(self._raw(x) >> self.__val__)
    # Compound Assignment (mutate in place and return self; a missing __iXXX__
    # would make Python fall back to __XXX__ and rebind the name to a new object)
    def __iadd__(self, x):      self.__val__ += self._raw(x); return self
    def __isub__(self, x):      self.__val__ -= self._raw(x); return self
    def __imul__(self, x):      self.__val__ *= self._raw(x); return self
    def __idiv__(self, x):      self.__val__ /= self._raw(x); return self             # Python 2 `/=`
    def __imod__(self, x):      self.__val__ %= self._raw(x); return self
    def __ipow__(self, x):      self.__val__ **= self._raw(x); return self
    def __ifloordiv__(self, x): self.__val__ //= self._raw(x); return self
    def __itruediv__(self, x):  self.__val__ = operator.truediv(self.__val__, self._raw(x)); return self
    def __iand__(self, x):      self.__val__ &= self._raw(x); return self
    def __ior__(self, x):       self.__val__ |= self._raw(x); return self
    def __ixor__(self, x):      self.__val__ ^= self._raw(x); return self
    def __ilshift__(self, x):   self.__val__ <<= self._raw(x); return self
    def __irshift__(self, x):   self.__val__ >>= self._raw(x); return self
    # Casts
    def __nonzero__(self):      return self.__val__ != 0
    __bool__ = __nonzero__      # Python 3 name for the truth-value hook
    def __int__(self):          return int(self.__val__)
    def __float__(self):        return float(self.__val__)
    def __long__(self):         return self.__val__.__long__()              # Python 2 only
    # Conversions
    def __oct__(self):          return self.__val__.__oct__()               # Python 2 only
    def __hex__(self):          return self.__val__.__hex__()               # Python 2 only
    def __str__(self):          return str(self.__val__)
    # Random Ops
    def __index__(self):        return operator.index(self.__val__)         # TypeError for non-integers
    def __trunc__(self):        return self.__val__.__trunc__()
    def __coerce__(self, x):    return self.__val__.__coerce__(x)           # Python 2 only
    # Representation (%r rather than %d so float values are not truncated)
    def __repr__(self):         return "%s(%r)" % (self.__class__.__name__, self.__val__)
    # Define innertype, a function that returns the type of the inner value self.__val__
    def innertype(self):        return type(self.__val__)
    # Define set, a function that you can use to set the value of the instance
    def set(self, x):
        """Replace the wrapped value with *x* (a plain number or another instance).

        Raises TypeError when *x* is neither.
        """
        if   isinstance(x, self._numeric_types): self.__val__ = x
        elif isinstance(x, self.__class__): self.__val__ = x.__val__
        else: raise TypeError("expected a numeric type")
    # Pass anything else along to self.__val__ (only reached when normal lookup fails)
    def __getattr__(self, attr):
        return getattr(self.__val__, attr)

I put the entire class, with usage header and rough test suite here.

mgilson's suggestion of using @total_ordering will simplify this a bit.

As long as you follow the usage guidelines (e.g. using x *= 2 instead of x = x * 2), it seems that you'll be fine.

That said, simply wrapping the argument in a list and then modifying x[0] seems much easier -- but it was still a fun project.


Solution

  • The easiest thing to do is going to be to implement them all by hand. If this was something you were going to add to lots of classes then you might look at metaclasses (can be brain melting) or class decorators (much easier to deal with), but you should do it once by hand so you know what's going on.

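    The reason __getattr__ only works sometimes is that it is only called if the name it is looking for cannot be found on the class or any of its base classes. So if __xyz__ can be found on object, __getattr__ will not be called. In addition, for new-style classes (every class in Python 3), operators and built-ins such as +, *= or int() look special methods up on the type rather than on the instance, and that implicit lookup bypasses __getattr__ altogether; it is only consulted for explicit attribute access such as x.__add__.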