How do I get mypy to recognize that a function argument must be an instance of a subclass of a particular base class? Consider the following:
```python
# main.py
class A: ...

class B(A):
    def f(self, x):
        print(x)

class C(A):
    def f(self, x):
        print(x, "in C")
```
```python
# test.py
def call_f(instance):
    instance.f("Hello")

if __name__ == "__main__":
    from main import B, C
    b = B()
    call_f(b)
    c = C()
    call_f(c)
```
As shown in `main.py`, all subclasses of `A` implement a method `f`. `call_f` in `test.py` takes an instance of one of the subclasses of `A` and calls this method. An example of this is shown in the `if __name__ == "__main__":` section of `test.py`.
One way to type-annotate `call_f` in `test.py` would be the following:
```python
# test_typed.py
from typing import Union

from main import B, C

def call_f(instance: Union[B, C]) -> None:
    instance.f("Hello")

if __name__ == "__main__":
    b = B()
    call_f(b)
    c = C()
    call_f(c)
```
However, the disadvantage is that I have to add every new subclass of `A` to the annotation of `call_f`, which is repetitive.

Is there a better way to do this?
I suppose that taking a stab at an answer isn't a bad idea at this point, since I have enough information from the comments.
You first need to enforce that `f` is implemented by subclasses of `A`. Otherwise, you could write a subclass that doesn't implement `f`, and static type checking would (rightfully) point out that nothing prevents that from happening. Concretely, with `class A: ...` as in the question, annotating the parameter as `instance: A` makes mypy report that `"A" has no attribute "f"`. You could use `Union[B, C]` if you only wanted some subclasses to implement `f`, but you've already said that this is undesirable for extensibility reasons.
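To make that gap concrete, here is a minimal sketch that keeps `A` empty, as in the question's `main.py`, and adds a hypothetical subclass `D` (not from the question) that never defines `f`. mypy rejects the `instance.f(...)` line for exactly this reason; at runtime the problem only surfaces once a `D` reaches `call_f`.

```python
class A: ...

class B(A):
    def f(self, x: str) -> None:
        print(x)

# Hypothetical subclass that "forgets" to implement f.
class D(A):
    pass

def call_f(instance: A) -> None:
    # mypy flags this line ("A" has no attribute "f"): nothing declared
    # on A guarantees that every subclass provides f.
    instance.f("Hello")

call_f(B())  # works, because B happens to define f
try:
    call_f(D())  # only fails at runtime
except AttributeError as exc:
    print("AttributeError:", exc)
```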
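What you should do is have the function accept instances of the superclass `A`, and declare `f` on `A` as an abstract method. `ABC` then refuses to instantiate `A` or any subclass that doesn't override `f`, so the `NotImplementedError` in the base implementation is only raised if a subclass explicitly calls `super().f(...)`. Annotating `f` also lets mypy check the call's arguments, since it treats unannotated parameters as `Any`: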
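```python
from abc import ABC, abstractmethod

class A(ABC):
    @abstractmethod
    def f(self, x: str) -> None:
        # Only reachable via super().f(...) from a subclass: neither A nor a
        # subclass that doesn't override f can be instantiated.
        raise NotImplementedError("This method should be defined in subclasses of A.")

class B(A):
    def f(self, x: str) -> None:
        print(x)

class C(A):
    def f(self, x: str) -> None:
        print(x, "in C")
```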
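Then `call_f()` would look like the following:

```python
from main import A

def call_f(instance: A) -> None:
    # Accepts an instance of any subclass of A, including ones added later.
    instance.f("Hello")
```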
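Putting the pieces together, here is a self-contained sketch (in a single file for convenience) that shows both halves of the argument. `E` and `D` are hypothetical subclasses, not from the question: `E` is a newly added subclass that `call_f` accepts without any change to its annotation, and `D` forgets to override `f`, so `ABC` refuses to instantiate it.

```python
from abc import ABC, abstractmethod

class A(ABC):
    @abstractmethod
    def f(self, x: str) -> None:
        raise NotImplementedError("This method should be defined in subclasses of A.")

class B(A):
    def f(self, x: str) -> None:
        print(x)

class C(A):
    def f(self, x: str) -> None:
        print(x, "in C")

# Hypothetical new subclass: no change to call_f's annotation is needed.
class E(A):
    def f(self, x: str) -> None:
        print(x, "in E")

# Hypothetical subclass that forgets to override f; it stays abstract.
class D(A):
    pass

def call_f(instance: A) -> None:
    instance.f("Hello")

for obj in (B(), C(), E()):
    call_f(obj)

try:
    D()  # mypy also flags this call: D is abstract because f is not overridden
except TypeError as exc:
    print("TypeError:", exc)
```

Defining `D` itself succeeds; the error appears where `D` is instantiated, both at runtime (`TypeError`) and in mypy's static check of that call.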