python python-typing mypy pyright

mypy and variance on function returning a list of subclasses based on input parameter


I'm trying to annotate my code properly. Here's an example snippet to recreate my issue:

from pydantic import BaseModel
from typing import TypeVar


class Shape(BaseModel):
   name: str


class Circle(Shape):
   diameter: float


class Square(Shape):
   side: float


class Triangle(Shape):
   base: float
   height: float


T = TypeVar("T")


def generate_shapes(shapes: list[type[T]]) -> list[T]:
   generated_shapes = []
   for shape in shapes:
       if shape == Circle:
           generated_shapes.append(Circle(name=shape.__name__, diameter=10))
       elif shape == Square:
           generated_shapes.append(Square(name=shape.__name__, side=10))
       elif shape == Triangle:
           generated_shapes.append(Triangle(name=shape.__name__, base=10, height=10))
   return generated_shapes


if __name__ == "__main__":
   result = generate_shapes([Circle, Triangle])

In practice the code runs fine; however, mypy complains. In particular, the error I'm trying to understand is:

Argument 1 to "generate_shapes" has incompatible type "list[ModelMetaclass]"; expected "list[type[Never]]"

My objective is to have the output list type inferred from the input. With this annotation, VS Code does seem able to work it out.

My basic understanding is that mypy complains because lists are mutable and therefore invariant rather than covariant. Does that mean there is no way to get what I'm after (VS Code inferring the type and mypy passing)?


Solution

  • Your function isn't really generic. Effectively, it expects a list of Shape classes and returns a list of Shapes. There's no real reason to allow a more general list as an argument. If the caller somehow has a list like [Circle, Triangle, 1, Square], they can be responsible for removing non-Shape-subclass elements first.

    Further, the initially inferred type for generated_shapes is list[Any], which gets narrowed to list[Circle] by the first attempt to append anything to the list, so the later appends of Square and Triangle no longer fit. If your function isn't generic, you can be explicit about the type of generated_shapes up front. (mypy doesn't go so far as to infer the type of generated_shapes from the declared return type of the function.)

    def generate_shapes(shapes: list[type[Shape]]) -> list[Shape]:
       generated_shapes: list[Shape] = []
       for shape in shapes:
           if shape == Circle:
               generated_shapes.append(Circle(name=shape.__name__, diameter=10))
           elif shape == Square:
               generated_shapes.append(Square(name=shape.__name__, side=10))
           elif shape == Triangle:
               generated_shapes.append(Triangle(name=shape.__name__, base=10, height=10))
       return generated_shapes
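
A self-contained sketch of the same idea, in case it helps to run it end to end. It makes a few assumptions that are not in the original: standard-library dataclasses stand in for pydantic's BaseModel so that no third-party package is needed, Triangle is dropped for brevity, and the classes are compared with `is` rather than `==`. The last two lines show one way a caller might recover a specific subclass from the list[Shape] result.

```python
from dataclasses import dataclass


# Assumption: plain dataclasses replace pydantic.BaseModel for this sketch.
@dataclass
class Shape:
    name: str


@dataclass
class Circle(Shape):
    diameter: float


@dataclass
class Square(Shape):
    side: float


def generate_shapes(shapes: list[type[Shape]]) -> list[Shape]:
    # Annotating the empty list up front fixes its element type as Shape,
    # instead of leaving it to be inferred from the first append.
    generated_shapes: list[Shape] = []
    for shape in shapes:
        if shape is Circle:
            generated_shapes.append(Circle(name=shape.__name__, diameter=10))
        elif shape is Square:
            generated_shapes.append(Square(name=shape.__name__, side=10))
    return generated_shapes


result = generate_shapes([Circle, Square])

# The static type of each element is Shape; isinstance narrows it
# when a caller needs subclass-specific attributes.
circles = [s for s in result if isinstance(s, Circle)]
```

The trade-off is that the return type no longer depends on the argument: every caller gets list[Shape] and narrows where needed.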