python, mypy, python-typing

Can Mypy overload single objects and unpacked tuples?


The following is easy enough to implement at runtime, but it seems impossible to express in Mypy.

I use `*` unpacking for its compactness (e.g. foo(1, 2, ...)), but I also want to handle the single-element case separately, because forcing callers to unpack a one-element tuple adds a lot of unnecessary indexing. However, there seems to be no way to disambiguate the two overloads:

from typing import overload


@overload
def foo(a: int) -> int: # Impossible to distinguish inputs from overload below
    ...


@overload
def foo(*a: int) -> tuple[int, ...]:
    ...


def foo(*a: int) -> int | tuple[int, ...]:
    if len(a) == 1:
        return a[0]
    return a


assert foo(1) == 1 # This is the expected result, but how would the type checker know?
assert foo(1, 2) == (1, 2) # This is obviously the correct signature
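To illustrate why the two overloads collide, here is a rough runtime analogy (an assumption for illustration, not how mypy actually checks overloads): `inspect.Signature.bind` reports which signatures can accept a given call. `single`, `variadic` and `matching` are hypothetical names introduced only for this sketch.

```python
import inspect


# Hypothetical stand-ins for the two overload signatures above.
def single(a: int) -> int: ...


def variadic(*a: int) -> tuple[int, ...]: ...


def matching(args: tuple[object, ...], *candidates) -> list[str]:
    """Return the names of the candidates whose signature can bind args."""
    names = []
    for func in candidates:
        try:
            inspect.signature(func).bind(*args)
        except TypeError:
            # Wrong number of positional arguments for this signature.
            continue
        names.append(func.__name__)
    return names


print(matching((1,), single, variadic))    # ['single', 'variadic']
print(matching((1, 2), single, variadic))  # ['variadic']
```

A one-argument call fits both signatures, and both accept an `int`, so the call alone gives the checker nothing to choose between them.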

Is avoiding the unpacking altogether really the only way?


Solution

  • *args accepts zero or more positional arguments, so foo(1) matches both of your overloads. You need @overload signatures that cannot overlap:

    • If exactly one argument is passed, return it.
    • If at least two arguments are passed (two required plus zero or more extra), return a tuple of them.

    These translate to the following type hints:


    from typing import TypeVar, TypeVarTuple, overload, reveal_type  # Python 3.11+
    
    T = TypeVar('T')
    T2 = TypeVar('T2')
    Ts = TypeVarTuple('Ts')
    
    @overload
    def foo(a: T, /) -> T:
        ...
    
    @overload
    def foo(a0: T, a1: T2, /, *rest: *Ts) -> tuple[T, T2, *Ts]:
        ...
    
    def foo(a: T, /, *rest: *Ts) -> T | tuple[T, *Ts]:
        if len(rest) == 0:
            return a
        
        return (a, *rest)
    
    reveal_type(foo(1))           # mypy & pyright => int
    reveal_type(foo(1, 2))        # mypy & pyright => tuple[int, int]
    reveal_type(foo(1, 2., '3'))  # mypy           => tuple[int, float, Literal['3']]
                                  # pyright        => tuple[int, float, str]
    
    foo()                         # error: no overload accepts zero arguments
    foo(2, bar=4)                 # error: no overload accepts keyword arguments