Tags: python, pandas, numpy, python-typing

How to combine type hint using bound type variable and static types for maximum flexibility?


I would like to add type hints to a simple function. Since it internally only uses numpy calls, it is very flexible with its inputs. Basically, it accepts all array-like objects, for which there is the numpy.typing.ArrayLike type.

Defining the return type is not as straightforward, however. For some input types, such as lists, the numpy functions cast the return value to a numpy array, which means I could use -> npt.NDArray.

Other input types, such as pandas.DataFrame, which I use heavily in my code, are returned as the same type they came in as. This would usually be a good reason to use a TypeVar bound to the input type.
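To make the two behaviours concrete, here is a minimal runtime sketch. The sample values and the column labels are made up for illustration; only the numpy calls mirror the function further down.

```python
import numpy as np
import pandas as pd

values = [[4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]

# A nested list is converted to a NumPy array.
from_list = np.log10(np.divide(values, 4e-10))
print(type(from_list).__name__)  # ndarray

# A DataFrame comes back as a DataFrame, with its labels intact.
# (The column labels are hypothetical, chosen only for this check.)
df = pd.DataFrame(values, columns=["a", "b", "c"])
from_frame = np.log10(np.divide(df, 4e-10))
print(type(from_frame).__name__)  # DataFrame
print(list(from_frame.columns))  # ['a', 'b', 'c']
```

So a single return annotation has to cover both "converted to ndarray" and "same type as the input".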

How do I keep the flexibility of numpy while also providing meaningful type hints, for example for mypy?

Note: The example method is used to calculate sound pressure levels.
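For context, a small sketch of what the calculation does. It assumes that p2 holds squared sound pressures in Pa² and that 4e-10 is the squared reference pressure (20 µPa)² commonly used for airborne sound; the constant name P_REF_SQUARED is introduced here only for readability.

```python
import numpy as np

# Assumption: (20e-6 Pa) ** 2, the usual reference for sound in air.
P_REF_SQUARED = 4e-10


def get_decibels(p2):
    """Untyped version: level in dB relative to the squared reference."""
    return 10 * np.log10(np.divide(p2, P_REF_SQUARED))


# The reference pressure itself maps to 0 dB.
print(get_decibels(4e-10))  # 0.0

# A squared pressure of 1 Pa² is roughly 94 dB.
print(round(float(get_decibels(1.0)), 2))  # 93.98
```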

Version 1 and Version 2 below are both perfectly valid code; the functions differ only in their type hints.

They cause different errors with static type checkers, demonstrating the limits of each approach.

Version 1:

import numpy as np
import numpy.typing as npt
import pandas as pd

def get_decibels1(p2: npt.ArrayLike) -> npt.NDArray:
    return (10 * np.log10(np.divide(p2, 4e-10)))

df = pd.DataFrame([[4, 5, 6], [7, 8, 9]])
get_decibels1(df).columns
# --- Causes mypy error:
# error: "ndarray[Any, dtype[Any]]" has no attribute "columns"  [attr-defined]

Version 2:

from typing import TypeVar

T = TypeVar('T', bound=npt.ArrayLike)
def get_decibels2(p2: T) -> T:
    return (10 * np.log10(np.divide(p2, 4e-10)))

ls = [4.0, 5, 6]
get_decibels2(ls).shape
# --- Causes mypy error:
#  error: "list[float]" has no attribute "shape"  [attr-defined]

How do I sensibly combine the two approaches?


Update:

I figured I might be able to approach this with @overload, but that does not seem to work either, because the signatures overlap.

from typing import TypeVar, Union, overload

T = TypeVar('T', bound=Union[pd.DataFrame, pd.Series])
@overload
def get_decibels(p2: T) -> T: ...

@overload
def get_decibels(p2: npt.ArrayLike) -> npt.NDArray: ...

def get_decibels(p2: npt.ArrayLike):
    return (10 * np.log10(np.divide(p2, 4e-10)))
# --- Causes mypy error:
#  error: Overloaded function signatures 1 and 2 overlap with incompatible return types  [overload-overlap]

I was under the impression that mypy just chooses the first signature that matches, which would have solved it. Any ideas on how to resolve this?


Solution

  • It turns out the update was already 95% of the way there: mypy and other type checkers do read overloaded signatures in the order they are defined and use the first one that matches.

    By default, mypy reports overlapping signatures because they are considered unsafe (see the mypy docs). However, the docs also say:

    Note that in cases where you ignore the overlapping overload error, mypy will usually still infer the types you expect at callsites.

    So using type: ignore[overload-overlap] results in the expected behaviour.

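    from typing import TypeVar, Union, overload

    import numpy as np
    import numpy.typing as npt
    import pandas as pd

    T = TypeVar('T', bound=Union[pd.DataFrame, pd.Series])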
    @overload
    def get_decibels(p2: T) -> T: ...  # type: ignore[overload-overlap]
    
    @overload
    def get_decibels(p2: npt.ArrayLike) -> npt.NDArray: ...
    
    def get_decibels(p2):
        return (10 * np.log10(np.divide(p2, 4e-10)))
    
    ls = [[4, 5, 6], [7, 8, 9]]
    df = pd.DataFrame(ls)
    
    reveal_type(get_decibels(ls))
    # note: Revealed type is "numpy.ndarray[Any, numpy.dtype[Any]]"
    
    reveal_type(get_decibels(df))
    # note: Revealed type is "pandas.core.frame.DataFrame"
    

    The overlap is considered unsafe because of situations where the calling code is already somewhat mislabeled, for example:

    df: npt.ArrayLike = pd.DataFrame([[0,1,2],[3,4,5]])
    get_decibels(df)  # mypy will wrongly deduce the return type to be `npt.NDArray`
    

    Since I consider this an edge case, ignoring the overlap seems like a good compromise for adding type hints to this function, especially since a pd.DataFrame behaves like an np.ndarray in most cases anyway.
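The edge case can also be observed at runtime. The following is a minimal sketch, not part of the original answer: it uses an untyped get_decibels so that it runs on its own, and the sample values are made up. The final np.asarray call is one possible way to get a real array when one is needed, not the only one.

```python
import numpy as np
import numpy.typing as npt
import pandas as pd


def get_decibels(p2):
    return 10 * np.log10(np.divide(p2, 4e-10))


# The annotation hides the fact that this is really a DataFrame.
data: npt.ArrayLike = pd.DataFrame([[4e-10, 4e-8], [4e-6, 4e-4]])

result = get_decibels(data)

# With the overloads from the answer, mypy would match the ArrayLike
# signature and report an ndarray here. At runtime it is a DataFrame.
print(isinstance(result, pd.DataFrame))  # True
print(isinstance(result, np.ndarray))  # False

# Converting explicitly makes the runtime type agree with the hint.
as_array = np.asarray(result)
print(type(as_array).__name__)  # ndarray
```

As long as DataFrames are annotated as DataFrames at the call site, the first overload matches and this mismatch does not come up.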