
PEP 586: How does Literal actually help variable return types?


Context

I have just read PEP 586. In the motivation, the authors say the following:

numpy.unique will return either a single array or a tuple containing anywhere from two to four arrays depending on three boolean flag values.

(...)

There is currently no way of expressing the type signatures of these functions: PEP 484 does not include any mechanism for writing signatures where the return type varies depending on the value passed in.

We propose adding Literal types to address these gaps.

But I don't really see how adding Literal types helps with that. I also disagree with the statement that

PEP 484 does not include any mechanism for writing signatures where the return type varies depending on the value passed in.

As far as I understand, Union can be used in such a case.

Question

How can the numpy.unique return type be annotated with Literal?


Solution

  • The problem that Literal solves

    You took exactly the right passage from PEP 586. To highlight the two crucial words here one more time, this is about

    signatures where the return type varies depending on the value passed in.

    That is one of the applications of the Literal type, and the statement is in fact correct.

    Could you annotate a function that returns one of two different types under various (not further defined) circumstances before? Sure, as you correctly pointed out, a Union can be used for that.

    Could you annotate a function that returns one of two different types depending on different argument types (or combinations thereof) passed into it? Yes, that is what the @overload decorator is for.
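    For contrast, here is a minimal sketch of that pre-Literal, type-based overloading. The function twice is hypothetical and only serves as an illustration: the checker picks the overload from the argument's type (int or str), which is exactly what it cannot do when two calls pass the same type, bool, with different values.

```python
from typing import Union, overload


# Hypothetical example: the return type follows the *type* of the argument,
# something @overload could already express before PEP 586.
@overload
def twice(x: int) -> int: ...
@overload
def twice(x: str) -> str: ...


def twice(x: Union[int, str]) -> Union[int, str]:
    # int * 2 doubles the number; str * 2 repeats the string.
    return x * 2


number = twice(3)   # a type checker infers int
text = twice("ab")  # a type checker infers str
```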

    But annotate a function that returns one of two different types depending on the value of an argument passed into it? This was not possible before Literal.

    To accomplish that, we now use Literal in combination with the @overload decorator. Consider the following example before we get to np.unique.


    Simple example

    Say I have a very silly function double that doubles a float passed to it. It returns the result either as a float or, if a special flag is set, as a str:

    from typing import Union
    
    def double(
        num: float,
        as_string: bool = False,
    ) -> Union[float, str]:
        num *= 2
        if as_string:
            return str(num)
        return num
    

    This annotation is perfectly fine: the return type covers both possible outcomes, a float or a str being returned.

    But now say I have another function that accepts only a str:

    def need_str(x: str) -> None:
        print(x.startswith("2"))
    

    What do I do if I want to pass the output of double as an argument to need_str?

    output = double(1.1, as_string=True)
    need_str(output)
    

    This is a problem for a strict type checker. The code runs fine, because we pass as_string=True and therefore know that the output is a string. But the static type checker (mypy here) only sees the return type of the first function and the parameter type of the second, and rightfully complains:

    error: Argument 1 to "need_str" has incompatible type "Union[float, str]"; expected "str"  [arg-type]
    

    It sees that output could well be a float. It doesn't know what double does inside. How do we fix that? Well, before Literal, the simplest solution I can think of would have been to do something like this:

    output = double(1.1, as_string=True)
    assert isinstance(output, str)
    need_str(output)
    

    That is reasonable, satisfies the type checker and gets the job done.
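    Another pre-Literal workaround is typing.cast. A minimal sketch, repeating the Union-annotated double and need_str from above so the snippet runs on its own: unlike the assert, cast performs no runtime check at all; it merely tells the type checker to trust you.

```python
from typing import Union, cast


def double(num: float, as_string: bool = False) -> Union[float, str]:
    num *= 2
    if as_string:
        return str(num)
    return num


def need_str(x: str) -> None:
    print(x.startswith("2"))


# cast() has no runtime effect; it only tells the checker the value is a str.
output = cast(str, double(1.1, as_string=True))
need_str(output)  # prints True
```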

    But now that we have Literal, we can solve this (arguably) much more elegantly:

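    from typing import Literal, Union, overload
    
    # Matches calls that omit as_string or pass the literal False.
    @overload
    def double(
        num: float,
        as_string: Literal[False] = ...,
    ) -> float: ...
    
    # Matches calls that pass the literal True.
    @overload
    def double(
        num: float,
        as_string: Literal[True],
    ) -> str: ...
    
    # Fallback for a plain bool whose value is not known statically.
    @overload
    def double(
        num: float,
        as_string: bool,
    ) -> Union[float, str]: ...
    
    def double(
        num: float,
        as_string: bool = False,
    ) -> Union[float, str]:
        num *= 2
        if as_string:
            return str(num)
        return num
    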
    

    Now, if I try this again, the type checker understands the specific call to double, infers the returned value to be of type str and considers the next function call to be type safe:

    output = double(1.1, as_string=True)
    need_str(output)
    

    Adding reveal_type(output) makes mypy tell us Revealed type is "builtins.str".

    I hope this illustrates the new capability and shows that it did not exist before. There are other things you can do with Literal, but those are off-topic here.


    How this helps np.unique

    As the documentation you linked reveals, np.unique has essentially four different possible return types:

    1. One array of the same dtype as ar
    2. A 2-tuple of one array of the same dtype as ar followed by one integer array
    3. A 3-tuple of one array of the same dtype as ar followed by two integer arrays
    4. A 4-tuple of one array of the same dtype as ar followed by three integer arrays

    Which type it is (as well as the meaning of the values) depends entirely on the values passed to the parameters return_index, return_inverse, and return_counts:

    1. If all those arguments are False (default)
    2. If exactly one of those arguments is True
    3. If exactly two of those arguments are True
    4. If all three of those arguments are True

    Thus, the situation is analogous to the simple example from above. It's just that there are a lot more @overloads to define, since we have 2³ = 8 combinations of arguments to reflect in our calls.
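    To make the counting concrete, here is a small pure-Python sketch (no NumPy needed) that enumerates the 2³ flag combinations and the number of arrays np.unique returns for each, namely the unique values plus one array per True flag. Each line it prints corresponds to one @overload.

```python
from itertools import product

FLAGS = ("return_index", "return_inverse", "return_counts")

# Each flag is independently True or False, giving 2**3 = 8 combinations.
combinations = list(product([False, True], repeat=len(FLAGS)))

# The unique values are always returned, plus one extra array per True flag.
for combo in combinations:
    n_arrays = 1 + sum(combo)
    enabled = [name for name, flag in zip(FLAGS, combo) if flag]
    print(f"{n_arrays} array(s): {enabled or 'all False'}")
```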

    Now, if I had too much time on my hands and wanted to write a useless wrapper around np.unique, I would demonstrate how Literal can be used to properly annotate all different call variations and satisfy even the strictest type checker...

    *sigh*

    A useless wrapper around np.unique

    from collections.abc import Sequence
    from typing import Literal, TypeAlias, TypeVar, Union, overload
    
    import numpy as np
    from numpy.typing import NDArray
    
    T = TypeVar("T", bound=np.generic)
    NPint: TypeAlias = np.intp  # NumPy's index/count arrays use intp
    
    # All options `False` (also matches calls that omit every flag):
    @overload
    def np_unique(
        items: Sequence[T],
        *,
        return_index: Literal[False] = ...,
        return_inverse: Literal[False] = ...,
        return_counts: Literal[False] = ...,
    ) -> NDArray[T]: ...
    
    # One option `True`:
    @overload
    def np_unique(
        items: Sequence[T],
        *,
        return_index: Literal[True],
        return_inverse: Literal[False] = ...,
        return_counts: Literal[False] = ...,
    ) -> tuple[NDArray[T], NDArray[NPint]]: ...
    
    @overload
    def np_unique(
        items: Sequence[T],
        *,
        return_index: Literal[False] = ...,
        return_inverse: Literal[True],
        return_counts: Literal[False] = ...,
    ) -> tuple[NDArray[T], NDArray[NPint]]: ...
    
    @overload
    def np_unique(
        items: Sequence[T],
        *,
        return_index: Literal[False] = ...,
        return_inverse: Literal[False] = ...,
        return_counts: Literal[True],
    ) -> tuple[NDArray[T], NDArray[NPint]]: ...
    
    # Two options `True`:
    @overload
    def np_unique(
        items: Sequence[T],
        *,
        return_index: Literal[True],
        return_inverse: Literal[True],
        return_counts: Literal[False] = ...,
    ) -> tuple[NDArray[T], NDArray[NPint], NDArray[NPint]]: ...
    
    @overload
    def np_unique(
        items: Sequence[T],
        *,
        return_index: Literal[True],
        return_inverse: Literal[False] = ...,
        return_counts: Literal[True],
    ) -> tuple[NDArray[T], NDArray[NPint], NDArray[NPint]]: ...
    
    @overload
    def np_unique(
        items: Sequence[T],
        *,
        return_index: Literal[False] = ...,
        return_inverse: Literal[True],
        return_counts: Literal[True],
    ) -> tuple[NDArray[T], NDArray[NPint], NDArray[NPint]]: ...
    
    # Three options `True`:
    @overload
    def np_unique(
        items: Sequence[T],
        *,
        return_index: Literal[True],
        return_inverse: Literal[True],
        return_counts: Literal[True],
    ) -> tuple[NDArray[T], NDArray[NPint], NDArray[NPint], NDArray[NPint]]: ...
    
    def np_unique(
        items: Sequence[T],
        *,
        return_index: Literal[True, False] = False,
        return_inverse: Literal[True, False] = False,
        return_counts: Literal[True, False] = False,
    ) -> Union[
        NDArray[T],
        tuple[NDArray[T], NDArray[NPint]],
        tuple[NDArray[T], NDArray[NPint], NDArray[NPint]],
        tuple[NDArray[T], NDArray[NPint], NDArray[NPint], NDArray[NPint]],
    ]:
        return np.unique(
            np.array(items),
            return_index=return_index,
            return_inverse=return_inverse,
            return_counts=return_counts,
        )
    

    It is worth noting that such extensive overloading is, in principle, even more expressive than this: if one of the options produced an array with yet another dtype, we could still annotate that case properly here.

    It is also worth mentioning that IMHO this goes too far, and I don't think it is good style. A function should not have this many fundamentally distinct call signatures; it is what some would call a "code smell"...
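    One way to avoid that smell is to split a flag-driven function into several functions, each with exactly one return type, so no @overload or Literal is needed at all. Below is a minimal pure-Python sketch of the idea; the two functions are hypothetical stand-ins, not NumPy's API, and unlike np.unique they keep first-seen order instead of sorting. (NumPy 2.0 took a similar route by adding np.unique_values, np.unique_counts, np.unique_inverse and np.unique_all.)

```python
from collections import Counter
from collections.abc import Hashable, Sequence
from typing import TypeVar

H = TypeVar("H", bound=Hashable)


# Hypothetical stand-ins: each function has a single, fixed return type,
# so a type checker needs no overloads to know what comes back.
def unique_values(items: Sequence[H]) -> list[H]:
    # dict.fromkeys drops duplicates while keeping first-seen order.
    return list(dict.fromkeys(items))


def unique_counts(items: Sequence[H]) -> tuple[list[H], list[int]]:
    # Counter keeps keys in first-seen order (dicts are ordered since 3.7).
    counts = Counter(items)
    return list(counts), list(counts.values())
```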

    But as far as typing capabilities go, I'd say it is better to have them and not need them than the other way around.


    Hope this helps.