Numba: how to get the indices of all rows that contain at least one NaN value?


Suppose I have a 2D NumPy array (m by n), and I want to get the indices of all rows that contain at least one NaN value.

It is relatively straightforward to do it in pure numpy as follows:

import numpy as np

X = np.array([[1, 2], [3, np.nan], [6, 9]])

has_nan_idx = np.isnan(X).any(axis=1)

has_nan_idx
# array([False,  True, False])
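Strictly speaking, has_nan_idx above is a boolean mask rather than a list of row indices. If the integer indices themselves are needed, np.flatnonzero (equivalently np.nonzero(mask)[0]) converts the mask. A minimal NumPy sketch; the names has_nan_mask and has_nan_rows are illustrative only:

```python
import numpy as np

X = np.array([[1, 2], [3, np.nan], [6, 9]])

# Boolean mask: True for each row that contains at least one NaN.
has_nan_mask = np.isnan(X).any(axis=1)

# Convert the mask into integer row indices.
has_nan_rows = np.flatnonzero(has_nan_mask)
print(has_nan_rows)  # [1]
```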

How can I achieve the same with numba's njit? When I try, I get an error, because numba does not support any() with an axis argument.

Thanks.


Solution

  • If you use guvectorize, you automatically get the features of a generalized ufunc, such as the axis keyword.

    For example:

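    import numpy as np
    from numba import guvectorize

    # Layout "(n)->()": each call sees one 1-D row x and produces one boolean.
    # A scalar output is passed in as a length-1 array, hence out[0].
    @guvectorize(["void(float64[:], boolean[:])"], "(n)->()")
    def isnan_with_axis(x, out):
        out[0] = False
        for i in range(x.size):
            if np.isnan(x[i]):
                out[0] = True
                break  # one NaN is enough, stop scanning this row

    X = np.array([[1, 2], [3, np.nan], [6, 9]])
    isnan_with_axis(X, axis=1)
    # array([False,  True, False])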