
Walk MRO for Python special methods returning NotImplemented


I have a hierarchy of classes for algebraic objects that implement special methods such as __mul__ and __add__, and use multiple inheritance. I had somehow assumed that Python (>= 3.5) would walk the method resolution order (MRO) to find the first method that does not return NotImplemented. Alas, this does not seem to be the case. Consider the following minimal example:

class A():
    def __mul__(self, other):
        return "A * %s" % other

class B():
    def __mul__(self, other):
        if isinstance(other, int):
            return "B * %s" % other
        else:
            return NotImplemented

class C(B, A):
    pass

class D(B, A):
    def __mul__(self, other):
        res = B.__mul__(self, other)
        if res is NotImplemented:
            res = A.__mul__(self, other)
        return res

In this code, I have implemented D with the desired behavior:

>>> d = D()
>>> d * 1
'B * 1'
>>> d * "x"
'A * x'

However, I actually would have expected C to behave the same as D, which it does not:

>>> c = C()
>>> c * 1
'B * 1'
>>> c * "x"
Traceback (most recent call last):
File "<ipython-input-23-549ffa5b5ffb>", line 1, in <module>
    c * "x"
TypeError: can't multiply sequence by non-int of type 'C'

I understand what's happening, of course: Python simply calls the first __mul__ it finds in the MRO and uses whatever it returns. (I had just hoped that NotImplemented would be handled as a special value.)
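A quick check shows where the lookup stops. The snippet redefines A, B and C from above so it runs on its own. Attribute lookup for __mul__ finds B's method first. The NotImplemented it returns goes straight back to the * operator, which then tries the right-hand operand (the str) and finally raises TypeError. A.__mul__ is never consulted.

```python
class A:
    def __mul__(self, other):
        return "A * %s" % other

class B:
    def __mul__(self, other):
        if isinstance(other, int):
            return "B * %s" % other
        return NotImplemented

class C(B, A):
    pass

c = C()
# Attribute lookup stops at the first class in the MRO that defines __mul__.
print([k.__name__ for k in C.__mro__])  # ['C', 'B', 'A', 'object']
print(C.__mul__ is B.__mul__)           # True
# Calling the method directly shows the sentinel that the * operator receives;
# nothing goes on to try A.__mul__ afterwards.
print(c.__mul__("x"))                   # NotImplemented
```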

My question is whether there's any way to avoid writing boilerplate code like D.__mul__ (which would be essentially the same for every numerical special method, in every class). I suppose I could write a class decorator or metaclass to generate all these methods automatically, but I was hoping there would be some easier (standard library) way, or that someone has already done something like this.
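For reference, the class-decorator route could look roughly like the sketch below. The names mro_fallback, _make_chain and CHAINED_OPS are hypothetical. For each listed operator defined by more than one class in the MRO, the decorator installs a method that calls those implementations in MRO order and returns the first result that is not NotImplemented.

```python
# Hypothetical set of operator names to chain; extend as needed.
CHAINED_OPS = ("__add__", "__radd__", "__mul__", "__rmul__")


def _make_chain(name, impls):
    """Build a method that tries each implementation in MRO order."""
    def method(self, other):
        for impl in impls:
            result = impl(self, other)
            if result is not NotImplemented:
                return result
        return NotImplemented  # every implementation declined
    method.__name__ = name
    return method


def mro_fallback(cls):
    """Hypothetical class decorator: chain the operators in CHAINED_OPS
    along cls.__mro__, skipping implementations that return NotImplemented."""
    for name in CHAINED_OPS:
        # Functions defined directly in each class, in MRO order
        # (cls itself first, if it defines the method).
        impls = tuple(k.__dict__[name] for k in cls.__mro__ if name in k.__dict__)
        if len(impls) > 1:
            setattr(cls, name, _make_chain(name, impls))
    return cls


class A:
    def __mul__(self, other):
        return "A * %s" % other


class B:
    def __mul__(self, other):
        if isinstance(other, int):
            return "B * %s" % other
        return NotImplemented


@mro_fallback
class C(B, A):
    pass


print(C() * 1)    # B * 1
print(C() * "x")  # A * x
```

The decorator captures the implementations when it runs, so methods added to base classes afterwards are not picked up. The same loop could instead be run from a metaclass, or from __init_subclass__ on Python 3.6+, to apply it to every subclass automatically.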


Solution

  • Python walks up the MRO only when you ask it to; it does not implicitly keep checking higher classes. Change your code to use cooperative inheritance with super() (a request to continue to the next class in the MRO) wherever you would otherwise return NotImplemented, and it should work. This removes the need for C or D to define __mul__ at all, since they add nothing to its functionality:

    class A():
        def __mul__(self, other):
            return "A * %s" % other
    
    class B():
        def __mul__(self, other):
            if isinstance(other, int):
                return "B * %s" % other
            try:
                return super().__mul__(other)  # Delegate to next class in MRO
            except AttributeError:
                return NotImplemented  # If no other class to delegate to, NotImplemented
    
    class C(B, A):
        pass
    
    class D(B, A):
        pass  # Look ma, no __mul__!
    

    Then testing:

    >>> d = D()
    >>> d * 1
    'B * 1'
    >>> d * 'x'
    'A * x'
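Because C and D are now identical, C behaves the same way. The sketch below restates the answer's A and B so it runs on its own. It prints C's MRO to show why super() inside B.__mul__ lands on A: super() searches type(self).__mro__ starting just after B.

```python
class A:
    def __mul__(self, other):
        return "A * %s" % other

class B:
    def __mul__(self, other):
        if isinstance(other, int):
            return "B * %s" % other
        try:
            return super().__mul__(other)  # next class after B in type(self).__mro__
        except AttributeError:
            return NotImplemented

class C(B, A):
    pass

c = C()
# In C's MRO, A comes right after B, so super() inside B.__mul__ resolves to A.
print([k.__name__ for k in C.__mro__])  # ['C', 'B', 'A', 'object']
print(super(B, c).__mul__("x"))         # A * x  (what B.__mul__ delegates to)
print(c * 1)                            # B * 1
print(c * "x")                          # A * x
```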
    

    The magic of super() is that it works even in multiple-inheritance scenarios where one class (B here) knows nothing about A, yet it will still happily delegate to A (or any other class) if a subclass happens to inherit from both. If no later class in the MRO defines __mul__, super().__mul__ raises AttributeError. We catch that and return NotImplemented as before, so cases like this still behave as expected (Python falls back to str's reflected multiplication, which rejects the non-int operand and raises TypeError):

    >>> class E(B): pass
    >>> e = E()
    >>> e * 1
    'B * 1'
    >>> e * 'x'
    Traceback (most recent call last):
    ...
    TypeError: can't multiply sequence by non-int of type 'E'
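One caveat with the try/except version: it also catches any AttributeError raised *inside* the next class's __mul__. A genuine bug there would silently become NotImplemented and a confusing TypeError. A variant that looks the method up first with getattr(super(), "__mul__", None) and only then calls it avoids this. It is a sketch of an alternative, not part of the original answer. Buggy and F are hypothetical classes used only to show the difference.

```python
class A:
    def __mul__(self, other):
        return "A * %s" % other

class B:
    def __mul__(self, other):
        if isinstance(other, int):
            return "B * %s" % other
        # Look up the next __mul__ in the MRO without calling it, so an
        # AttributeError raised inside that method is not swallowed here.
        next_mul = getattr(super(), "__mul__", None)
        if next_mul is None:
            return NotImplemented  # no later class in the MRO defines __mul__
        return next_mul(other)

class D(B, A):
    pass

class E(B):
    pass

class Buggy:  # hypothetical class with a deliberate bug
    def __mul__(self, other):
        return self.no_such_attribute * other  # raises AttributeError

class F(B, Buggy):
    pass

print(D() * "x")  # A * x
# E() * "x" still raises TypeError, while F() * "x" now surfaces the
# AttributeError from Buggy instead of hiding it behind NotImplemented.
```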