python numpy numpy-ndarray mypy

How to use mypy to ensure that a NumPy array of floats is passed as function argument?


Can mypy check that a NumPy array of floats is passed as a function argument? For the code below mypy is silent when an array of integers or booleans is passed.

import numpy as np
import numpy.typing as npt

def half(x: npt.NDArray[np.cfloat]):
    return x/2

print(half(np.full(4,2.1)))
print(half(np.full(4,6)))    # want mypy to complain about this
print(half(np.full(4,True))) # want mypy to complain about this

Solution

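  • Mypy can check the element type of a NumPy array, but only through the dtype parameter of numpy.typing.NDArray, and only when the argument's dtype is known statically. First, note that np.cfloat is a complex type (an alias of np.complex128, removed in NumPy 2.0), so npt.NDArray[np.cfloat] declares an array of complex floats. For real floats, use np.float64, or np.floating[Any] to accept any floating-point precision. Mypy stays silent here because np.full(4, 6) and np.full(4, True) are called without a dtype argument, and NumPy's type stubs annotate that form as returning NDArray[Any]. Since Any is compatible with every type, mypy accepts the argument regardless of the annotation. If the array is created with an explicit dtype, e.g. np.full(4, 6, dtype=np.int64), mypy can see the dtype and reports the incompatible argument. Arrays often reach a function typed as NDArray[Any], so a runtime check within the function is still needed to guarantee that only arrays of the intended dtype are accepted.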
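A minimal sketch combining the static annotation with a runtime check. It assumes a NumPy version that ships numpy.typing.NDArray (1.21 or later) and that real floating-point arrays of any precision are what is wanted. The annotation therefore uses np.floating[Any] rather than np.cfloat, and np.issubdtype performs the runtime check. Raising TypeError and the wording of the message are illustrative choices, not a NumPy convention.

```python
from __future__ import annotations

from typing import Any

import numpy as np
import numpy.typing as npt


def half(x: npt.NDArray[np.floating[Any]]) -> npt.NDArray[np.floating[Any]]:
    # Static side: mypy rejects arguments whose dtype it can infer and that is
    # not a floating type, e.g. an array created with dtype=np.int64.
    # Runtime side: arrays typed as NDArray[Any] slip past mypy, so the actual
    # dtype is checked here as well (this also rejects complex arrays).
    if not np.issubdtype(x.dtype, np.floating):
        raise TypeError(f"expected a floating-point array, got dtype {x.dtype}")
    return x / 2


print(half(np.full(4, 2.1)))                    # float64: accepted
print(half(np.full(4, 2.1, dtype=np.float32)))  # float32: also a floating type

# With an explicit integer dtype, mypy should flag this call statically:
# half(np.full(4, 6, dtype=np.int64))

for bad in (np.full(4, 6), np.full(4, True)):
    try:
        half(bad)  # typically NDArray[Any] to mypy, so only the runtime check fires
    except TypeError as exc:
        print(exc)
```

Under mypy, uncommenting the int64 call should produce an incompatible-argument error. The two calls in the loop type-check because their dtype is not visible statically, and they are rejected only at runtime.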