I have a base class with two classes derived from it. I want the methods to behave differently depending on whether the argument is of the same type as the derived class, or only an instance of the base class but of a different type. This is the current implementation:
    class MyBase:
        def __init__(self, foo: int):
            self.foo = foo

        def __eq__(self, other):
            return self.foo == other.foo

    class MyDerived_1(MyBase):
        def __init__(self, foo: int, bar: int):
            super().__init__(foo)
            self.bar = bar

    class MyDerived_2(MyBase):
        def __init__(self, foo: int, bar: int):
            super().__init__(foo)
            self.bar = bar

        def __eq__(self, other):
            if type(other) == type(self):
                return self.bar == other.bar
            elif isinstance(other, MyBase):
                return super().__eq__(other)
            else:
                return False
In the fourth last line I have to reference MyBase explicitly. Perhaps this is fine, but my understanding is that a main point of "super" is that it should allow you to change the base class without having to rewrite anything in the class. So a potential issue with this solution is that if the base class is changed, __init__ will be fine because it calls "super", but __eq__ will not update its behaviour.
So I attempted replacing "MyBase" with "type(super)" or "type(super())", but these do not reference the super class, they reference the class of the object "super".
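As a minimal sketch of that observation (the class names `Base` and `Child` are hypothetical, not part of the code above): `type(super())` yields the built-in `super` proxy type, whereas the `__bases__` attribute of the defining class lists its direct base classes.

```python
class Base:
    pass


class Child(Base):
    def proxy_type(self):
        # super() returns a proxy object; its type is the built-in 'super'
        return type(super())

    def direct_bases(self):
        # __class__ refers to the class this method is defined in (Child)
        return __class__.__bases__


c = Child()
proxy = c.proxy_type()    # the built-in 'super' type, not Base
bases = c.direct_bases()  # (Base,)
```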
Note that this question differs from "Get parent class name?", "Get defining class of unbound method object in Python 3", etc., because they are looking for the parent classes once the object has been initialised.
I guess that I should be able to find the super class by running up the MRO. But this seems like a bad solution given that I'm not looking for the whole inheritance tree, I just want to know the type of the super class.
Is there a way to pull that information out of "super"?
First of all, you want to return NotImplemented from __eq__ when you encounter a type you don't support, so that Python can also give the second operand a chance to participate in the equality test. From the Python datamodel documentation:
Numeric methods and rich comparison methods should return this value if they do not implement the operation for the operands provided. (The interpreter will then try the reflected operation, or some other fallback, depending on the operator.)
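A minimal sketch of that fallback, using two hypothetical classes (`Plain` and `Wrapper` are illustrative names, not from the question). `Plain` only knows how to compare with itself, while `Wrapper` also knows about `Plain`:

```python
class Plain:
    """Hypothetical class that only knows how to compare with itself."""

    def __init__(self, value):
        self.value = value

    def __eq__(self, other):
        if not isinstance(other, Plain):
            # unknown type: let Python try the other operand
            return NotImplemented
        return self.value == other.value


class Wrapper:
    """Hypothetical class that also knows how to compare with Plain."""

    def __init__(self, value):
        self.value = value

    def __eq__(self, other):
        if isinstance(other, (Plain, Wrapper)):
            return self.value == other.value
        return NotImplemented


p = Plain(3)
w = Wrapper(3)

direct = p.__eq__(w)  # NotImplemented: Plain does not know Wrapper
reflected = p == w    # Python then tries Wrapper.__eq__(w, p): True
no_match = p == "3"   # both sides decline, so Python falls back to identity: False
```

Had `Plain.__eq__` returned `False` instead of `NotImplemented`, `p == w` would be `False` and `Wrapper` would never be consulted.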
Your code should really just delegate to super().__eq__() when other is not an instance of the same type; there is no need to test for the base type here, because the base class should already take care of testing for the right type or protocol.
Next, you could make use of the Python 3 __class__ closure to access the class that a method is defined on; Python adds this closure whenever you use either super() or __class__ in a function definition that is nested inside of a class definition:
    class MyBase:
        # ...
        def __eq__(self, other):
            if not isinstance(other, __class__):
                # we can't handle the other type, inform Python
                return NotImplemented
            return self.foo == other.foo

    class MyDerived_2(MyBase):
        # ...
        def __eq__(self, other):
            if isinstance(other, __class__):
                # if other is an instance of MyDerived_2, only test for 'bar'
                return self.bar == other.bar
            # otherwise fall back to the base behaviour
            return super().__eq__(other)
Note that I used isinstance() rather than type() tests, because you'd want subclasses of MyDerived_2 to inherit this behaviour.
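Filling in the elided parts with the constructors from the question gives a runnable sketch of this approach. `MyDerived_3` is a hypothetical subclass added here only to illustrate the isinstance() point:

```python
class MyBase:
    def __init__(self, foo: int):
        self.foo = foo

    def __eq__(self, other):
        if not isinstance(other, __class__):
            return NotImplemented
        return self.foo == other.foo


class MyDerived_1(MyBase):
    def __init__(self, foo: int, bar: int):
        super().__init__(foo)
        self.bar = bar


class MyDerived_2(MyBase):
    def __init__(self, foo: int, bar: int):
        super().__init__(foo)
        self.bar = bar

    def __eq__(self, other):
        if isinstance(other, __class__):
            return self.bar == other.bar
        return super().__eq__(other)


class MyDerived_3(MyDerived_2):
    """Hypothetical subclass; it inherits MyDerived_2.__eq__ unchanged."""


same_type = MyDerived_2(1, 5) == MyDerived_2(2, 5)   # compares bar: True
cross_type = MyDerived_2(1, 5) == MyDerived_1(1, 9)  # compares foo: True
subclass = MyDerived_3(0, 5) == MyDerived_2(7, 5)    # compares bar: True
unrelated = MyDerived_2(1, 5) == 42                  # NotImplemented on both sides: False
```

Inside the inherited method, `__class__` still refers to `MyDerived_2` (the class where the method was defined), not to `type(self)`.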
Instead of testing for a specific class hierarchy, you could also rely on duck-typing; if the other object has the right attribute names, then just assume it can be used to compare with:
    class MyBase:
        # ...
        def __eq__(self, other):
            try:
                return self.foo == other.foo
            except AttributeError:
                # we can't handle the other type, inform Python
                return NotImplemented

    class MyDerived_2(MyBase):
        # ...
        def __eq__(self, other):
            try:
                return self.bar == other.bar
            except AttributeError:
                # otherwise fall back to the base behaviour
                return super().__eq__(other)
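A short sketch of what the duck-typed version accepts, again with the constructors from the question filled in. The `lookalike` object is a hypothetical stand-in built with `types.SimpleNamespace`; it is unrelated to the class hierarchy but happens to have a `foo` attribute:

```python
from types import SimpleNamespace


class MyBase:
    def __init__(self, foo: int):
        self.foo = foo

    def __eq__(self, other):
        try:
            return self.foo == other.foo
        except AttributeError:
            return NotImplemented


class MyDerived_2(MyBase):
    def __init__(self, foo: int, bar: int):
        super().__init__(foo)
        self.bar = bar

    def __eq__(self, other):
        try:
            return self.bar == other.bar
        except AttributeError:
            return super().__eq__(other)


# not part of the hierarchy, but it has a 'foo' attribute
lookalike = SimpleNamespace(foo=1)

by_bar = MyDerived_2(1, 5) == MyDerived_2(2, 5)  # both have bar: True
by_foo = MyDerived_2(1, 5) == lookalike          # no bar, falls back to foo: True
no_attrs = MyDerived_2(1, 5) == object()         # neither attribute: False
```

The trade-off is that any object with a matching attribute name now takes part in the comparison, which may or may not be what you want.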