python, recursion, python-typing

Typing for tuple or list of tuples in recursive function


Here is an example I wrote.

from typing import Tuple, List


def performances(
    pt: Tuple[float, float] | List[Tuple[float, float]]
) -> float | list[float]:
    if isinstance(pt, list):
        output = []
        for pt_i in pt:
            output.append(performances(pt_i))
    else:
        output = pt[0] + pt[1]
    return output


if __name__ == "__main__":

    # This works as intended
    t1 = performances((1, 2))
    print(t1)

    # This works as intended
    t2 = performances([(1, 2), (3, 4)])
    print(t2)

    # This works, although I expected it to raise a TypeError
    t3 = performances((1, 2, 3, 4))
    print(t3)

How can I apply type checking here to make sure I only accept tuples with two elements (or a list of them)?

I also tried without the Tuple and List types from typing, using just list[tuple[float, float]], but the result is the same.

And:

from typing import Tuple, Iterable

type Point = Tuple[float, float]

def performances(pt: Point | Iterable[Point]) -> float | list[float]:
    ...

But the result is the same.

There is probably something I don't understand about typing, but I don't know what. Thank you.


Solution

  • The return type depends on the argument: pt: Tuple[float, float] leads to float, whereas pt: List[Tuple[float, float]] leads to List[float].

    With a plain union as the return type, Mypy and other static analysis tools have no way to tell which input shape leads to which output shape. This causes errors around append(performances(pt_i)), because your static analysis tool will infer that you are building a List[Union[float, List[float]]].

    A good way to remedy this is to use @overload from the typing module.

    from typing import Tuple, List, overload
    
    
    @overload
    def performances(pt: Tuple[float, float]) -> float:
        ...
    
    
    @overload
    def performances(pt: List[Tuple[float, float]]) -> List[float]:
        ...
    
    
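    # Implementation: callers only ever see the two overloads above;
    # this union signature is what the checker uses inside the body.
    def performances(
        pt: Tuple[float, float] | List[Tuple[float, float]]
    ) -> float | list[float]:
        if isinstance(pt, list):
            # Recurse over each point in the list.
            return list(map(performances, pt))
        return pt[0] + pt[1]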
    

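    With these overloads, static analysis tools can match each input type to its output type. Note that annotations are never enforced at runtime: performances((1, 2, 3, 4)) is only reported as an error when you run a static checker such as mypy over the code, not when you execute it. See my earlier comment on your post for why Pytest alone is not enough to validate this scenario.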
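If you also want a TypeError when the code actually runs, the tuple shape has to be checked explicitly, because Python does not validate arguments against their annotations. Below is a minimal sketch, assuming you want both: the @overload signatures for the static checker plus a runtime check. The Point alias and the _check_point helper are hypothetical names introduced for illustration; neither comes from the typing module.

```python
from __future__ import annotations

from typing import List, Tuple, overload

Point = Tuple[float, float]


def _check_point(pt: object) -> None:
    # Hypothetical helper (not part of typing): raise TypeError unless
    # pt is a tuple of exactly two ints/floats.
    if not (
        isinstance(pt, tuple)
        and len(pt) == 2
        and all(isinstance(v, (int, float)) for v in pt)
    ):
        raise TypeError(f"expected a tuple of two floats, got {pt!r}")


@overload
def performances(pt: Point) -> float: ...


@overload
def performances(pt: List[Point]) -> List[float]: ...


def performances(pt: Point | List[Point]) -> float | List[float]:
    if isinstance(pt, list):
        # Check every element up front so nested lists such as
        # [[(1, 2)]] are rejected instead of silently recursed into.
        for pt_i in pt:
            _check_point(pt_i)
        return [performances(pt_i) for pt_i in pt]
    _check_point(pt)
    return pt[0] + pt[1]


if __name__ == "__main__":
    print(performances((1, 2)))            # 3
    print(performances([(1, 2), (3, 4)]))  # [3, 7]
    try:
        performances((1, 2, 3, 4))
    except TypeError as exc:
        print(exc)  # expected a tuple of two floats, got (1, 2, 3, 4)
```

The two mechanisms are independent: mypy still flags bad calls before the program runs, while _check_point only fires when a bad value actually reaches the function at runtime.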