Tags: python, mypy

Is List truly invariant in Python?


I have a script toy.py, and I find its behavior kind of confusing.

from typing import List


class A:
    a = 1

class B(A):
    b = 2


def func(input_arg: List[A]) -> None:
    """Debugging mypy."""
    print(f"{input_arg=:}")
    print(f"{type(input_arg)=:}")

If I append

if __name__ == "__main__":
    arg1 = B()
    reveal_type([arg1])
    func([arg1])

mypy passes:

mypy toy.py
toy.py:22:17: note: Revealed type is "builtins.list[toy.B*]"

but if I instead append

if __name__ == "__main__":
    arg2 = [B()]
    reveal_type(arg2)
    func(arg2)

which I thought was equivalent to the first case, I get an error:

mypy toy.py
toy.py:26:17: note: Revealed type is "builtins.list[toy.B*]"
toy.py:27:10: error: Argument 1 to "func" has incompatible type "List[B]"; expected "List[A]"
toy.py:27:10: note: "List" is invariant -- see https://mypy.readthedocs.io/en/stable/common_issues.html#variance
toy.py:27:10: note: Consider using "Sequence" instead, which is covariant
Found 1 error in 1 file (checked 1 source file)

If List is invariant, why does the first case pass?

mypy --version
mypy 0.931

Solution

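  • This behavior has to do with the way mypy uses context to infer types. When an expression is passed directly as a function argument, the declared type of the parameter is used as context for inferring the type of that expression. In your first example, [arg1] is checked against List[A] inside the call itself, so it is inferred as List[A] (B is a subtype of A) and invariance never comes into play. reveal_type([arg1]) shows list[B] only because it is a separate statement with no such context.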

    However, mypy only uses this context-based inference within a single statement. The following examples, based on the mypy documentation, illustrate a more extreme case:

    This is allowed, and uses single-statement context to infer the type of an empty list as list[int]:

    def foo(arg: list[int]) -> None:
        print('Items:', ''.join(str(a) for a in arg))
    
    foo([])  # OK
    

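    But here, the context would have to flow backwards from the statement foo(a) to the earlier assignment a = [], just as in your second example, and mypy doesn't do that. In your case, arg2 = [B()] is inferred as List[B] at the assignment, and because List is invariant, List[B] is then rejected where List[A] is expected.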

    def foo(arg: list[int]) -> None:
        print('Items:', ''.join(str(a) for a in arg))
    
    a = []  # Error: Need type annotation for "a"
    foo(a)
    

    Interestingly, using an assignment expression doesn't help either: the value assigned to a is inferred on its own before the call, so the parameter type is never used as context:

    def foo(arg: list[int]) -> None:
        print('Items:', ''.join(str(a) for a in arg))
    
    foo(a := [1.1])  # Error: "a" has incompatible type "list[float]"
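Applied to the question's second example, here is a minimal sketch of two common workarounds, assuming the same A and B classes (func_readonly is a hypothetical name introduced only for illustration). One option is to annotate the variable so that [B()] receives its context in the assignment statement itself; the other is to accept Sequence[A], as mypy's own note suggests, because Sequence is covariant. The comments describe what mypy is expected to report; at runtime both calls simply run.

```python
from typing import List, Sequence


class A:
    a = 1


class B(A):
    b = 2


def func(input_arg: List[A]) -> None:
    """Same signature as in the question: mypy expects exactly List[A]."""
    print(f"{type(input_arg)=}")


def func_readonly(input_arg: Sequence[A]) -> int:
    """Hypothetical read-only variant: Sequence[A] is covariant in A."""
    # Only reads from the sequence, so accepting a List[B] is safe.
    return sum(item.a for item in input_arg)


# Workaround 1: annotate the variable. The declared type List[A] is used
# as context for [B()] in the same statement, so mypy accepts the call.
arg2: List[A] = [B()]
func(arg2)

# Workaround 2: keep the inferred List[B] and accept Sequence[A] instead.
arg3 = [B()]  # mypy infers List[B]
total = func_readonly(arg3)  # OK for mypy: List[B] is a Sequence[A]
print(total)  # 1
```

The annotation is the smaller change when func really needs a list it can mutate; switching the parameter to Sequence[A] is the better fit when func only reads its argument.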
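As for why mypy treats List as invariant in the first place, a short sketch shows the unsoundness it prevents (add_plain_a is a hypothetical helper). A function typed to take List[A] may legally append a plain A, which would corrupt a list that the caller still treats as List[B]. The type: ignore comment silences the error mypy would otherwise report, so the script runs and the problem becomes visible at runtime.

```python
from typing import List


class A:
    a = 1


class B(A):
    b = 2


def add_plain_a(items: List[A]) -> None:
    # Legal for a List[A]: any A, not just a B, may be appended.
    items.append(A())


bs = [B()]  # mypy infers List[B]

# mypy rejects this call because List is invariant; the ignore comment
# forces it through to show what invariance protects against.
add_plain_a(bs)  # type: ignore[arg-type]

last = bs[-1]
# The list the caller believes holds only B objects now holds a plain A,
# so any code relying on the attribute b would fail.
print(type(last).__name__, hasattr(last, "b"))  # A False
```

Sequence avoids this problem because it has no mutating methods such as append, which is why it can safely be covariant.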