How to use an index array on a numpy.array inside numba.njit()?
In the following code, if numba.njit is used, the code exits with an error. I found that the error is due to "b = a[idx]", but that should be valid in plain Python. How can I fix this in numba? Thanks.
    import numba
    import numpy as np

    @numba.njit()
    def test(a):
        idx = np.where(a > 5)
        b = a[idx]   # this line raises an error under numba.njit
        return b

    a = np.linspace(0, 15, 16).reshape([4, 4])
    b = test(a)
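The Numba docs say that only a subset of advanced indexing is supported: only one advanced index is allowed, and it has to be a one-dimensional array. For a 2D array, np.where(a > 5) returns a tuple of two index arrays (one for rows, one for columns), so a[idx] uses two advanced indices, which falls outside that supported subset.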
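To see why the original code trips over this rule, here is a plain NumPy sketch (no numba needed) that inspects what np.where returns for the 4x4 array from the question; rows and cols are just illustrative names:

```python
import numpy as np

a = np.linspace(0, 15, 16).reshape(4, 4)

# For a 2D input, np.where(condition) returns one index array per dimension.
idx = np.where(a > 5)
print(type(idx), len(idx))

# Unpacking shows the two separate advanced indices that a[idx] would use.
rows, cols = idx
print(rows)
print(cols)
```

Indexing with idx therefore means indexing with two integer arrays at once, which is the multi-advanced-index case the Numba docs exclude.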
And if you run your code without numba, you can see that the result is a 1D array anyway:

    >>> a[np.where(a > 5)]
    array([ 6.,  7.,  8.,  9., 10., 11., 12., 13., 14., 15.])
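A plain NumPy check (again without numba) that flattening first gives exactly the same values in the same order, since both np.where and ravel() walk the array in row-major order; via_where, via_flat_where and via_mask are just illustrative names:

```python
import numpy as np

a = np.linspace(0, 15, 16).reshape(4, 4)

# Original approach: two index arrays (rows, cols) from np.where on a 2D array.
via_where = a[np.where(a > 5)]

# Flattened approach: a single 1D integer index array.
flat = a.ravel()
via_flat_where = flat[np.where(flat > 5)[0]]

# Flattened approach: a single 1D boolean mask.
via_mask = flat[flat > 5]

print(via_where.ndim)
print(np.array_equal(via_where, via_flat_where), np.array_equal(via_where, via_mask))
```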
So you can flatten the array with ravel() and operate on the 1D array directly:
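    import numba as nb
    import numpy as np

    @nb.njit()
    def test(a):
        flat = a.ravel()              # 1D version of a, in row-major order
        idx = np.where(flat > 5)[0]   # a single 1D integer index array
        b = flat[idx]
        return b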
Or, even simpler, index with a boolean mask:
    import numba as nb

    @nb.njit()
    def test(a):
        flat = a.ravel()
        return flat[flat > 5]   # a single 1D boolean mask is a supported index
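If you would rather avoid fancy indexing altogether, explicit loops are a common style in Numba code. The sketch below is an alternative, not part of the answer above: select_greater and its threshold parameter are hypothetical names, and the try/except fallback is an assumption added only so the snippet also runs as plain Python when numba is not installed.

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # Assumed fallback for illustration: run as plain Python without numba.
    def njit(func):
        return func


@njit
def select_greater(a, threshold):
    # Hypothetical helper: collect elements of a 2D array greater than
    # threshold, in row-major order, using explicit loops only.
    n_rows, n_cols = a.shape

    # First pass: count matches so the output can be allocated once.
    count = 0
    for i in range(n_rows):
        for j in range(n_cols):
            if a[i, j] > threshold:
                count += 1

    # Second pass: copy the matching elements into the output array.
    out = np.empty(count, dtype=a.dtype)
    k = 0
    for i in range(n_rows):
        for j in range(n_cols):
            if a[i, j] > threshold:
                out[k] = a[i, j]
                k += 1
    return out


a = np.linspace(0, 15, 16).reshape(4, 4)
print(select_greater(a, 5.0))
```

The two passes (count, then fill) let the output be allocated once at its exact size, and the elements come out in the same row-major order as a[np.where(a > 5)].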