
How can I get subclass return types when calling methods defined on the base class?


I'm trying to write a class hierarchy in Python in which subclasses override a method, predict, to return a narrower type that is itself a subclass of the parent's return type. This works fine when I instantiate the subclass and call predict directly: the returned value has the expected narrow type. However, when I call a different method defined on the base class (predict_batch), which itself calls predict, the narrow return type is lost.

Some context: my program has to support two types of image segmentation models, "instance" and "semantic". The outputs of these two models are very different, so I was planning a parallel class hierarchy to store their outputs (i.e. BaseResult, InstResult, and SemResult). This would let parts of the client code stay general by working with BaseResult when they don't need to know which specific type of model was used.

Here is a toy code example:

from abc import ABC, abstractmethod
from typing import List

from overrides import overrides

##################
# Result classes #
##################


class BaseResult(ABC):
    """Abstract container class for result of image segmentation"""

    pass


class InstResult(BaseResult):
    """Stores the result of instance segmentation"""

    pass


class SemResult(BaseResult):
    """Stores the result of semantic segmentation"""

    pass


#################
# Model classes #
#################


class BaseModel(ABC):
    def predict_batch(self, images: List) -> List[BaseResult]:
        return [self.predict(img) for img in images]

    @abstractmethod
    def predict(self, image) -> BaseResult:
        raise NotImplementedError()


class InstanceSegModel(BaseModel):
    """performs instance segmentation on images"""

    @overrides
    def predict(self, image) -> InstResult:
        return InstResult()


class SemanticSegModel(BaseModel):
    """performs semantic segmentation on images"""

    @overrides
    def predict(self, image) -> SemResult:
        return SemResult()


########
# main #
########

# placeholder for illustration 
images = [None, None, None]

model = InstanceSegModel()
single_result = model.predict(images[0])  # has type InstResult
batch_result = model.predict_batch(images)  # has type List[BaseResult]

In the code above, I would like batch_result to have type List[InstResult].

At runtime none of this matters, and my code executes just fine. But the static type checker (Pylance) in my editor (VS Code) complains when client code treats batch_result as the narrower type. I can only think of these two possible solutions, and neither feels clean to me:

  1. Use the cast function from the typing module
  2. Override predict_batch in the subclasses even though the logic doesn't change
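For comparison, the two workarounds above might look roughly like this. This is a minimal sketch: it drops the `overrides` decorator and SemResult to stay short, and `OverridingInstanceSegModel` is a hypothetical class name. One assumption changes the base class: predict_batch is annotated as returning `Sequence[BaseResult]` instead of `List[BaseResult]`. Because `List` is invariant, `List[InstResult]` is not a subtype of `List[BaseResult]`, so type checkers such as mypy and Pyright flag the narrowed override in workaround 2 as incompatible unless the base method returns a covariant type like `Sequence`.

```python
from abc import ABC, abstractmethod
from typing import List, Sequence, cast


class BaseResult(ABC):
    """Abstract container for a segmentation result."""


class InstResult(BaseResult):
    """Result of instance segmentation."""


class BaseModel(ABC):
    # Assumption: Sequence (covariant) instead of List (invariant), so that a
    # subclass may legally narrow this return type in workaround 2.
    def predict_batch(self, images: List) -> Sequence[BaseResult]:
        return [self.predict(img) for img in images]

    @abstractmethod
    def predict(self, image) -> BaseResult:
        raise NotImplementedError()


class InstanceSegModel(BaseModel):
    def predict(self, image) -> InstResult:
        return InstResult()


class OverridingInstanceSegModel(InstanceSegModel):
    # Workaround 2 (hypothetical class name): re-declare predict_batch only to
    # narrow its annotation; the body just repeats the base-class logic.
    def predict_batch(self, images: List) -> List[InstResult]:
        return [self.predict(img) for img in images]


images = [None, None, None]

# Workaround 1: cast() only informs the type checker; it performs no runtime check.
casted = cast(List[InstResult], InstanceSegModel().predict_batch(images))

# Workaround 2: the overriding method's annotation is used directly.
overridden = OverridingInstanceSegModel().predict_batch(images)
```

The cast silences the checker without verifying anything, and the override duplicates logic in every subclass, which is why both feel unsatisfying.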

Solution

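  • You can combine generics with inheritance to narrow an annotation declared in a parent class. Make BaseModel generic in its result type T and have each subclass specialize it (e.g. BaseModel[InstResult]); inherited methods such as predict_batch then report the subclass's result type.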

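    from abc import ABC, abstractmethod
    from typing import List, Generic, TypeVar
    
    from overrides import overrides
    
    # BaseResult, InstResult and SemResult are defined as in the question.
    
    # T is the concrete result type a model produces; the bound restricts
    # it to subclasses of BaseResult.
    T = TypeVar('T', bound=BaseResult)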
    
    
    class BaseModel(ABC, Generic[T]):
        def predict_batch(self, images: List) -> List[T]:
            return [self.predict(img) for img in images]
    
        @abstractmethod
        def predict(self, image) -> T:
            raise NotImplementedError()
    
    
    class InstanceSegModel(BaseModel[InstResult]):
        """performs instance segmentation on images"""
    
        @overrides
        def predict(self, image) -> InstResult:
            return InstResult()
    
    
    class SemanticSegModel(BaseModel[SemResult]):
        """performs semantic segmentation on images"""
    
        @overrides
        def predict(self, image) -> SemResult:
            return SemResult()
    
    
    images = [None, None, None]
    
    model = InstanceSegModel()
    single_result = model.predict(images[0])  # has type InstResult
    batch_result = model.predict_batch(images)  # has type List[InstResult]
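A typical payoff of this design is general client code that still preserves the narrow type. The self-contained sketch below repeats the generic BaseModel (with `T` bound to BaseResult, and the `overrides` decorator omitted so it runs without third-party packages) and adds a hypothetical helper, `predict_in_chunks`, that accepts any model. The chunking logic is purely illustrative.

```python
from abc import ABC, abstractmethod
from typing import Generic, List, TypeVar


class BaseResult(ABC):
    """Abstract container for a segmentation result."""


class InstResult(BaseResult):
    """Result of instance segmentation."""


class SemResult(BaseResult):
    """Result of semantic segmentation."""


# T is the concrete result type a model produces; the bound rejects
# meaningless specializations such as BaseModel[int].
T = TypeVar("T", bound=BaseResult)


class BaseModel(ABC, Generic[T]):
    def predict_batch(self, images: List) -> List[T]:
        return [self.predict(img) for img in images]

    @abstractmethod
    def predict(self, image) -> T:
        raise NotImplementedError()


class InstanceSegModel(BaseModel[InstResult]):
    def predict(self, image) -> InstResult:
        return InstResult()


class SemanticSegModel(BaseModel[SemResult]):
    def predict(self, image) -> SemResult:
        return SemResult()


def predict_in_chunks(model: BaseModel[T], images: List, chunk_size: int) -> List[T]:
    """Hypothetical client helper that works with any model.

    It only knows it was given some BaseModel, yet because it is generic in T,
    callers get back that model's concrete result type.
    """
    results: List[T] = []
    for start in range(0, len(images), chunk_size):
        results.extend(model.predict_batch(images[start:start + chunk_size]))
    return results


images = [None] * 5
inst_results = predict_in_chunks(InstanceSegModel(), images, chunk_size=2)  # List[InstResult]
sem_results = predict_in_chunks(SemanticSegModel(), images, chunk_size=2)  # List[SemResult]
```

Because the helper is parameterized by the same `T`, a type checker should infer List[InstResult] for the first call and List[SemResult] for the second, with no cast. Annotating the parameter as BaseModel[BaseResult] instead would not work: `T` is invariant by default, so an InstanceSegModel (a BaseModel[InstResult]) would be rejected there.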