python inheritance super superclass

Get the type of the super class in Python 3


I have a base class with two classes derived from it. I want the methods to behave differently depending on whether the argument is of the same type as the derived class, or is an instance of the base class but of a different type. This is the current implementation:

class MyBase:
    def __init__(self, foo: int):
        self.foo = foo 

    def __eq__(self, other):
        return self.foo == other.foo 


class MyDerived_1(MyBase):
    def __init__(self, foo: int, bar: int):
        super().__init__(foo)
        self.bar = bar


class MyDerived_2(MyBase):
    def __init__(self, foo: int, bar: int):
        super().__init__(foo)
        self.bar = bar 

    def __eq__(self, other):
        if type(other) == type(self):
            return self.bar == other.bar 
        elif isinstance(other, MyBase):
            return super().__eq__(other)
        else:
            return False

In the fourth-to-last line I have to reference MyBase explicitly. Perhaps this is fine, but my understanding is that a main point of super() is that it lets you change the base class without rewriting anything in the class. A potential issue with this solution is that if the base class is changed to something other than MyBase, __init__ will be fine because it calls super(), but __eq__ will not update its behaviour.

So I attempted replacing "MyBase" with "type(super)" or "type(super())", but these do not reference the super class, they reference the class of the object "super".
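
To make that observation concrete, here is a small sketch. The names Base, Derived and probe are made up for illustration. It shows that the object returned by super() is a proxy whose type is the built-in super class itself, and that super is an ordinary class whose type is type:

```python
class Base:
    pass


class Derived(Base):
    def probe(self):
        # super() returns a proxy object; its type is the built-in `super`,
        # not the parent class Base.
        proxy = super()
        return type(proxy), type(super)


proxy_type, super_type = Derived().probe()
print(proxy_type is super)  # the proxy is an instance of `super`
print(super_type is type)   # `super` itself is a class
print(proxy_type is Base)   # so neither expression yields the parent class
```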

Note that this question differs from:

Get parent class name?
Get defining class of unbound method object in Python 3
etc.

because those are looking for the parent classes once the object has been initialised.

I guess that I should be able to find the super class by running up the MRO. But this seems like a bad solution given that I'm not looking for the whole inheritance tree, I just want to know the type of the super class.
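
For reference, a minimal sketch of what walking the MRO could look like. The method name next_in_mro and the class Child are hypothetical. It looks up the defining class (available as __class__ inside a method) in the MRO of the instance and takes the entry after it, which is the same position that super() starts searching from:

```python
class MyBase:
    pass


class MyDerived_2(MyBase):
    def next_in_mro(self):
        # __class__ is the class this method is defined in (MyDerived_2),
        # even when self is an instance of a subclass.
        mro = type(self).__mro__
        return mro[mro.index(__class__) + 1]


class Child(MyDerived_2):
    pass


print(MyDerived_2().next_in_mro())
print(Child().next_in_mro())
```

Both calls give MyBase here. With multiple inheritance the entry after __class__ depends on the MRO of the actual instance, so it is not necessarily a direct base of the defining class.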

Is there a way to pull that information out of "super"?


Solution

  • First of all, you want to return NotImplemented from __eq__ when you encounter a type you don't support, so that Python can also give the second operand a chance to participate in the equality test. From the Python datamodel documentation:

    Numeric methods and rich comparison methods should return this value if they do not implement the operation for the operands provided. (The interpreter will then try the reflected operation, or some other fallback, depending on the operator.)

    Your code should really just delegate to super().__eq__() when other is not an instance of the same type. There is no need to test for the base type here; the base class should already take care of testing for the right type or protocol.

    Next, you could make use of the Python 3 __class__ closure to access the class that a method is defined on; Python adds this closure whenever you use either super() or __class__ in a function definition that is nested inside of a class definition:

    class MyBase:
        # ...
    
        def __eq__(self, other):
            if not isinstance(other, __class__):
                # we can't handle the other type, inform Python
                return NotImplemented
            return self.foo == other.foo 
    
    class MyDerived_2(MyBase):
        # ...
    
        def __eq__(self, other):
            if isinstance(other, __class__):
                # if other is an instance of MyDerived_2, only test for 'bar'
                return self.bar == other.bar 
            # otherwise fall back to the base behaviour
            return super().__eq__(other)
    

    Note that I used isinstance() rather than type() tests, because you'd want subclasses of MyDerived_2 to inherit this behaviour.

    Instead of testing for a specific class hierarchy, you could also rely on duck-typing; if the other object has the right attribute names, then just assume it can be used to compare with:

    class MyBase:
        # ...
    
        def __eq__(self, other):
            try:
                return self.foo == other.foo
            except AttributeError:
                # we can't handle the other type, inform Python
                return NotImplemented
    
    class MyDerived_2(MyBase):
        # ...
    
        def __eq__(self, other):
            try:
                return self.bar == other.bar
            except AttributeError:
                # otherwise fall back to the base behaviour
                return super().__eq__(other)
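
Putting the isinstance() variant together with the classes from the question, a minimal end-to-end sketch might look like the following. The variable names same_bar, same_foo and unrelated are only there for illustration, and the sample values are arbitrary:

```python
class MyBase:
    def __init__(self, foo: int):
        self.foo = foo

    def __eq__(self, other):
        if not isinstance(other, __class__):
            # we can't handle the other type, inform Python
            return NotImplemented
        return self.foo == other.foo


class MyDerived_1(MyBase):
    def __init__(self, foo: int, bar: int):
        super().__init__(foo)
        self.bar = bar


class MyDerived_2(MyBase):
    def __init__(self, foo: int, bar: int):
        super().__init__(foo)
        self.bar = bar

    def __eq__(self, other):
        if isinstance(other, __class__):
            # both operands are MyDerived_2: only 'bar' is compared
            return self.bar == other.bar
        # otherwise fall back to the base behaviour
        return super().__eq__(other)


# Same derived type: equal because bar matches, even though foo differs.
same_bar = MyDerived_2(1, 5) == MyDerived_2(2, 5)
# Different derived types: falls back to MyBase.__eq__, which compares foo.
same_foo = MyDerived_2(1, 5) == MyDerived_1(1, 9)
# Unrelated type: both sides return NotImplemented, so Python falls back
# to an identity comparison.
unrelated = MyDerived_2(1, 5) == "not a MyBase"

print(same_bar, same_foo, unrelated)
```

MyDerived_2 never names MyBase inside __eq__, so swapping the base class only requires changing the class statement.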