Tags: python, static-analysis, type-hinting, mypy, duck-typing

How do I get mypy to recognize that an argument of a function needs to be a subclass of a particular base class?


Consider the following:

# main.py
class A: ...

class B(A):
  def f(self, x):
    print(x)

class C(A):
  def f(self, x):
    print(x, "in C")

# test.py
def call_f(instance):
  instance.f("Hello")

if __name__ == "__main__":
  from main import B, C
  b = B()
  call_f(b)
  c = C()
  call_f(c)

As shown in main.py, every subclass of A implements a method f. call_f in test.py takes an instance of any subclass of A and calls that method, as demonstrated in the if __name__ == "__main__": block of test.py.

One way to type-hint call_f in test.py would be the following:

# test_typed.py
from typing import Union

from main import B, C

def call_f(instance: Union[B, C]) -> None:
  instance.f("Hello")

if __name__ == "__main__":
  # B and C are already imported at module level
  b = B()
  call_f(b)
  c = C()
  call_f(c)

However, the disadvantage here is that I have to add every new subclass of A to the annotation of call_f, which is repetitive.

Is there a better way to do this?


Solution

  • I suppose that taking a stab at an answer isn't a bad idea at this point, since I have enough information from the comments.

    You first need to guarantee that f is implemented by every subclass of A. Otherwise, someone could write a subclass that doesn't implement f, and static type checking would (rightly) point out that nothing prevents that: if call_f accepted A while A itself doesn't declare f, mypy would report that "A" has no attribute "f". You could use Union[B, C] if only some subclasses implemented f, but you've already said that this is undesirable for extensibility reasons.
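
A minimal sketch of that gap, assuming the original main.py where A does not declare f, plus a hypothetical subclass D that omits f. Python happily runs the code until the attribute lookup fails, and mypy flags the call inside call_f because A has no f:

```python
class A: ...

class B(A):
    def f(self, x: str) -> None:
        print(x)

# Hypothetical subclass that "forgets" to implement f.
class D(A):
    pass

def call_f(instance: A) -> None:
    # mypy reports on this line: "A" has no attribute "f"  [attr-defined]
    instance.f("Hello")

call_f(B())  # prints: Hello
try:
    call_f(D())  # nothing stopped D from being written or instantiated
except AttributeError as exc:
    print(f"Runtime failure: {exc}")
```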

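    What you should do is declare f on the superclass A as an abstract method and have the function accept instances of A. A subclass that doesn't override f then can't be instantiated, and f as defined in the superclass raises an error if it is ever invoked anyway: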

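    from abc import ABC, abstractmethod
    
    class A(ABC):
      @abstractmethod
      def f(self, x: str) -> None:
        # Normally unreachable: @abstractmethod prevents instantiating A or any subclass that doesn't override f
        raise NotImplementedError("This method should be defined in subclasses of A.")
    
    class B(A):
      def f(self, x: str) -> None:
        print(x)
    
    class C(A):
      def f(self, x: str) -> None:
        print(x, "in C")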
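
To see the enforcement in action, here is a minimal sketch with a hypothetical subclass D that does not override f. CPython refuses to instantiate it with a TypeError (the exact wording of the message varies between Python versions), and mypy reports the instantiation statically as well:

```python
from abc import ABC, abstractmethod

class A(ABC):
    @abstractmethod
    def f(self, x: str) -> None:
        raise NotImplementedError("This method should be defined in subclasses of A.")

class B(A):
    def f(self, x: str) -> None:
        print(x)

# Hypothetical subclass that does not override f.
class D(A):
    pass

b = B()  # fine: every abstract method is overridden

try:
    # mypy: Cannot instantiate abstract class "D" with abstract attribute "f"
    D()
except TypeError as exc:
    print(f"TypeError: {exc}")
```

Annotating f (here x: str) also matters: by default mypy does not type-check the bodies of unannotated functions.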
    

    Then, call_f() would look like the following:

    from main import A
    
    def call_f(instance: A) -> None:
      instance.f("Hello")
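
If you would rather not require B and C to inherit from A at all (closer to the duck-typing in the original test.py), a different technique is structural typing with typing.Protocol (Python 3.8+). mypy then accepts any object whose f method has a compatible signature, whether or not it subclasses anything. A minimal sketch, where the protocol name SupportsF is hypothetical:

```python
from typing import Protocol

class SupportsF(Protocol):
    # Any class with a compatible f(self, x: str) -> None matches structurally.
    def f(self, x: str) -> None: ...

class B:
    def f(self, x: str) -> None:
        print(x)

class C:
    def f(self, x: str) -> None:
        print(x, "in C")

def call_f(instance: SupportsF) -> None:
    instance.f("Hello")

if __name__ == "__main__":
    call_f(B())  # prints: Hello
    call_f(C())  # prints: Hello in C
```

The protocol is only enforced by the type checker; at runtime nothing stops a non-conforming object from being passed, so the abstract base class approach is still the one that also fails fast at instantiation time.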