Tags: python, generics, python-typing, typing, nested-generics

How can I define a TypeAlias for a nested Generic in Python?


I currently have this code:

from collections.abc import Iterable, Sequence
from typing import TypeVar

T = TypeVar("T")
Grid = Sequence[Sequence[T]]

def columns(grid: Grid) -> Iterable[list[T]]:
    return ([row[i] for row in grid] for i in range(len(grid[0])))

But I think the T in the alias Grid is not the same T as the one in the function's return type, so the two are never linked.

How do I define Grid so that I can write:

def columns(grid: Grid[T]) -> Iterable[list[T]]:
    ...

I've looked at typing.GenericAlias, but can't see how it helps me.

(I'm aware that Sequence[Sequence[T]] has no guarantee that the grid is actually rectangular, but that's not the problem I want to focus on here.)


Solution

  • When a type variable is used as a generic parameter, it can later be replaced by another type, including another type variable. This is mentioned in the Generic Alias Type documentation (the only place I found it):

    The __getitem__() method of generic containers will raise an exception to disallow mistakes like dict[str][str]:

    >>> dict[str][str]
    Traceback (most recent call last):
      File "<stdin>", line 1, in <module>
    TypeError: There are no type variables left in dict[str]
    

    However, such expressions are valid when type variables are used. The index must have as many elements as there are type variable items in the GenericAlias object’s __args__.

    >>> from typing import TypeVar
    >>> Y = TypeVar('Y')
    >>> dict[str, Y][int]
    dict[str, int]
    

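    So your definition of Grid is already fine; you only need to subscript it as Grid[T] in the function signature. (A bare Grid is treated by type checkers as Grid[Any], which is why its T is not linked to the T in the return type.) In the interactive interpreter, you will see: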

    >>> from collections.abc import Sequence
    >>> from typing import TypeVar
    >>> T, R = TypeVar('T'), TypeVar('R')
    >>> Grid = Sequence[Sequence[T]]
    >>> Grid
    collections.abc.Sequence[collections.abc.Sequence[~T]]
    >>> Grid[R]
    collections.abc.Sequence[collections.abc.Sequence[~R]]
    

    Mypy also analyzes the subscripted alias correctly:

    from collections.abc import Sequence, Iterable
    from typing import TypeVar
    
    T = TypeVar('T')
    Grid = Sequence[Sequence[T]]
    
    
    def columns(grid: Grid[T]) -> Iterable[list[T]]:
        return ([row[i] for row in grid] for i in range(len(grid[0])))
    
    
    c1: Iterable[list[int]] = columns([[1, 2, 3]])  # passes
    c2: Iterable[list[int]] = columns([[4, 5, '6']])
    # mypy error: List item 2 has incompatible type "str"; expected "int" (line 13, column 42)
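
If you want a quick sanity check without running mypy, the same points can be confirmed at runtime. This is only a minimal sketch: it shows that the alias carries T as its single parameter, that subscripting substitutes it, and that columns transposes as expected. The sample grid is my own illustration and not part of the question, and none of this replaces a static type check.

```python
from collections.abc import Iterable, Sequence
from typing import TypeVar

T = TypeVar("T")
Grid = Sequence[Sequence[T]]  # generic alias with one free type variable


def columns(grid: Grid[T]) -> Iterable[list[T]]:
    # Yield one list per column; assumes a non-empty, rectangular grid.
    return ([row[i] for row in grid] for i in range(len(grid[0])))


# The alias is generic over exactly one type variable, T.
print(Grid.__parameters__)  # (~T,)

# Subscripting substitutes T, which gives the same alias as writing
# the nested type out by hand.
print(Grid[int] == Sequence[Sequence[int]])  # True

# Illustrative data (hypothetical): a 2x3 grid of ints.
sample = [[1, 2, 3], [4, 5, 6]]
print(list(columns(sample)))  # [[1, 4], [2, 5], [3, 6]]
```

Because Grid[T] in the signature reuses the function's own T, the element type of the input and of the returned lists are tied together, which is exactly what the bare Grid annotation was missing.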