Tags: python, numpy, numpy-ndarray, numba

How to use an index of a numpy.array in numba.njit()?


In the following code, the function exits with an error when numba.njit is used. I found that the error comes from "b = a[idx]", even though that line works fine in plain Python. How can I fix this in Numba? Thanks.

import numba
import numpy as np

@numba.njit()
def test(a):
    idx = np.where(a > 5)
    b = a[idx]  # this line triggers the error under numba.njit
    return b

a = np.linspace(0, 15, 16).reshape([4, 4])
b = test(a)
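For reference, here is what np.where returns for this input in plain NumPy (no Numba involved). For a 2-D array it returns a tuple with one index array per axis, so a[idx] indexes with two advanced indices at once; on a flattened 1-D array it returns a 1-tuple holding a single index array.

```python
import numpy as np

# Same input as in the question: a 4x4 array holding 0.0 .. 15.0
a = np.linspace(0, 15, 16).reshape([4, 4])

idx = np.where(a > 5)
# For a 2-D input, np.where returns one index array per axis
rows, cols = idx
print(len(idx))  # 2
print(rows)      # [1 1 2 2 2 2 3 3 3 3]
print(cols)      # [2 3 0 1 2 3 0 1 2 3]

# Indexing with that tuple uses two advanced indices at once
b = a[idx]
print(b)         # [ 6.  7.  8.  9. 10. 11. 12. 13. 14. 15.]

# On a flattened (1-D) array, np.where returns a 1-tuple: a single advanced index
flat = a.ravel()
print(len(np.where(flat > 5)))  # 1
```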

Solution

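  • The Numba docs say that only a subset of advanced indexing is supported: only one advanced index is allowed, and it has to be a one-dimensional array. For a 2D input, np.where returns a tuple of two index arrays (one per axis), so a[idx] uses two advanced indices at once, which Numba rejects.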

    If you run your code without numba, you can see that the result is a 1D array anyway:

    >>> a[np.where(a > 5)]
    array([ 6.,  7.,  8.,  9., 10., 11., 12., 13., 14., 15.])
    

    So you can flatten the array first and operate on the 1D array directly:

    import numba as nb
    import numpy as np

    @nb.njit()
    def test(a):
        a = a.ravel()          # flatten to 1D
        idx = np.where(a > 5)  # now a single 1D advanced index
        b = a[idx]
        return b
    

    Or, even simpler, index with a boolean mask:

    import numba as nb
    import numpy as np

    @nb.njit()
    def test(a):
        a = a.ravel()     # flatten to 1D
        return a[a > 5]   # 1D boolean mask: a single advanced index
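If you also need the row and column positions (not only the values), one possible alternative, not part of the answer above, is to unpack the two index arrays from np.where and gather the values in a loop using plain integer indexing, which Numba supports in nopython mode. This is a minimal sketch: the function name gather_where is hypothetical, the threshold 5 is taken from the question, and the try/except fallback to a no-op decorator is an assumption added only so the snippet also runs where Numba is not installed.

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # Assumption: fall back to plain Python when numba is not installed
    def njit(func):
        return func


@njit
def gather_where(a):
    # For a 2-D input, np.where returns one index array per axis
    rows, cols = np.where(a > 5)
    out = np.empty(rows.size, dtype=a.dtype)
    for k in range(rows.size):
        # Basic integer indexing only: no advanced indices involved
        out[k] = a[rows[k], cols[k]]
    return out, rows, cols


a = np.linspace(0, 15, 16).reshape([4, 4])
b, rows, cols = gather_where(a)
print(b)
```

Because np.where lists the positions in row-major order, the gathered values come out in the same order as a[a > 5] on the original 2D array.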