I'm trying to annotate my code properly. Here's an example snippet to recreate my issue:
from pydantic import BaseModel
from typing import TypeVar


class Shape(BaseModel):
    name: str


class Circle(Shape):
    diameter: float


class Square(Shape):
    side: float


class Triangle(Shape):
    base: float
    height: float


T = TypeVar("T")


def generate_shapes(shapes: list[type[T]]) -> list[T]:
    generated_shapes = []
    for shape in shapes:
        if shape == Circle:
            generated_shapes.append(Circle(name=shape.__name__, diameter=10))
        elif shape == Square:
            generated_shapes.append(Square(name=shape.__name__, side=10))
        elif shape == Triangle:
            generated_shapes.append(Triangle(name=shape.__name__, base=10, height=10))
    return generated_shapes


if __name__ == "__main__":
    result = generate_shapes([Circle, Triangle])
In practice the code runs fine; however, mypy complains. In particular, the error I'm trying to understand is:
Argument 1 to "generate_shapes" has incompatible type "list[ModelMetaclass]"; expected "list[type[Never]]"
My objective is to have the output list type inferred from the input. With this annotation, VS Code does seem able to infer it.
My basic understanding is that mypy complains because lists are mutable and cannot be covariant. So I'm afraid there is no way to get what I'm after (VS Code inferring the type and mypy happy)?
Your function isn't really generic. Effectively, it expects a list of Shape classes and returns a list of Shapes. There's no real reason to allow a more general list as an argument. If the caller somehow has a list like [Circle, Triangle, 1, Square], they can be responsible for removing non-Shape-subclass elements first.
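As a rough illustration of that last point, a caller could filter a mixed list before passing it in. This is only a sketch: the helper name only_shape_classes is hypothetical, and plain classes stand in for the pydantic models so the snippet runs on its own.

```python
from typing import Any


class Shape:
    """Stand-in for the pydantic Shape model (assumption: a plain class)."""


class Circle(Shape):
    pass


class Square(Shape):
    pass


class Triangle(Shape):
    pass


def only_shape_classes(candidates: list[Any]) -> list[type[Shape]]:
    # Hypothetical helper: keep only objects that are classes deriving from Shape.
    # The isinstance check comes first because issubclass raises TypeError
    # when its first argument is not a class (e.g. the integer 1).
    return [c for c in candidates if isinstance(c, type) and issubclass(c, Shape)]


mixed = [Circle, Triangle, 1, Square]
shape_classes = only_shape_classes(mixed)
```

The filtered list can then be handed to a function that is annotated to take list[type[Shape]].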
Further, the initially inferred type for generated_shapes is list[Any], which gets narrowed to list[Circle] by the first attempt to append anything to the list. If your function isn't generic, you can be explicit about the type of generated_shapes up front. (mypy doesn't go so far as to infer the type of generated_shapes from the declared return type of the function.)
def generate_shapes(shapes: list[type[Shape]]) -> list[Shape]:
    generated_shapes: list[Shape] = []
    for shape in shapes:
        if shape == Circle:
            generated_shapes.append(Circle(name=shape.__name__, diameter=10))
        elif shape == Square:
            generated_shapes.append(Square(name=shape.__name__, side=10))
        elif shape == Triangle:
            generated_shapes.append(Triangle(name=shape.__name__, base=10, height=10))
    return generated_shapes
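With a non-generic signature like this, the caller gets back a list[Shape] and can recover a concrete subclass with isinstance where it matters. The following is a minimal sketch of that call-site pattern, assuming dataclasses as stand-ins for the pydantic models (so no validation or coercion happens) and only two of the shapes for brevity.

```python
from dataclasses import dataclass


@dataclass
class Shape:
    name: str


@dataclass
class Circle(Shape):
    diameter: float


@dataclass
class Triangle(Shape):
    base: float
    height: float


def generate_shapes(shapes: list[type[Shape]]) -> list[Shape]:
    # Annotating the accumulator up front keeps it from being inferred
    # from the first append.
    generated_shapes: list[Shape] = []
    for shape in shapes:
        if shape is Circle:
            generated_shapes.append(Circle(name=shape.__name__, diameter=10))
        elif shape is Triangle:
            generated_shapes.append(Triangle(name=shape.__name__, base=10, height=10))
    return generated_shapes


result = generate_shapes([Circle, Triangle])  # statically a list[Shape]

sizes: list[float] = []
for item in result:
    # isinstance narrows Shape to the concrete subclass inside each branch.
    if isinstance(item, Circle):
        sizes.append(item.diameter)
    elif isinstance(item, Triangle):
        sizes.append(item.base * item.height / 2)
```

The trade-off is that the element type is Shape rather than a per-call union, so subclass-specific attributes are only available after such a check.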