Suppose I have a 2D NumPy array (m by n) and I want the indices of all rows that contain at least one NaN value.
This is relatively straightforward in pure NumPy:
import numpy as np
X = np.array([[1, 2], [3, np.nan], [6, 9]])
has_nan_idx = np.isnan(X).any(axis=1)
has_nan_idx
# array([False,  True, False])
How can I achieve the same with Numba's njit? I get an error because Numba does not support any with an axis argument.
Thanks.
If you use guvectorize, the result is a generalized ufunc, so you automatically get ufunc features such as the axis keyword.
For example:
from numba import guvectorize
@guvectorize(["void(float64[:], boolean[:])"], "(n)->()")
def isnan_with_axis(x, out):
    n = x.size
    out[0] = False
    for i in range(n):
        if np.isnan(x[i]):
            out[0] = True
            break
isnan_with_axis(X, axis=1)
# array([False, True, False])
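If you would rather stay with njit, as the question asks, one option is to write the row loop out by hand. The following is only a minimal sketch: the function name rows_with_nan is made up for illustration, and the try/except fallback is my own addition so the snippet still runs (uncompiled) where Numba is not installed.

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # Assumption for illustration: without Numba, run as plain Python.
    def njit(func):
        return func


@njit
def rows_with_nan(X):
    # Boolean mask with one entry per row, True if the row holds a NaN.
    m, n = X.shape
    out = np.zeros(m, dtype=np.bool_)
    for i in range(m):
        for j in range(n):
            if np.isnan(X[i, j]):
                out[i] = True
                break  # one NaN is enough, move on to the next row
    return out


X = np.array([[1, 2], [3, np.nan], [6, 9]])
mask = rows_with_nan(X)      # boolean mask, same as np.isnan(X).any(axis=1)
idx = np.flatnonzero(mask)   # integer row indices, if you need those
```

The explicit loop avoids the axis argument altogether, and the early break skips the rest of a row once a NaN is found.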