python numba

Numba njit gives an error when indexing a 2D np.array


I'm trying to index a 2D matrix B inside an njit function using a vector a that holds the column indices I want, where a is a row (slice) of matrix D. Here is a minimal example:

import numba as nb
import numpy as np

@nb.njit()
def test(N,P,B,D):
    for i in range(N):
        a = D[i,:]
        b =  B[i,a]
        P[:,i] =b

P = np.zeros((5,5))
B = np.random.random((5,5))*100
D = (np.random.random((5,5))*5).astype(np.int32)
print(D)
N = 5
print(P)
test(N,P,B,D)
print(P)

Numba raises an error at the line b = B[i,a]:

File "dj.py", line 10:
def test(N,P,B,D):
    <source elided>
        a = D[i,:]
        b =  B[i,a]
        ^

This is not usually a problem with Numba itself but instead often caused by
the use of unsupported features or an issue in resolving types.

I don't understand what I am doing wrong here. The code works without the @nb.njit() decorator.


Solution

  • Numba doesn't support all of the "fancy indexing" that NumPy does. In this case, the issue is selecting array elements with the array a.

    For your particular case, because you know the shape of b in advance, you could work around it like this:

    import numba as nb
    import numpy as np
    
    @nb.njit
    def test(N,P,B,D):
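        # b is reused on every iteration; its length matches one row of D
        b = np.empty(D.shape[1], dtype=B.dtype)
    
        for i in range(N):
            a = D[i,:]
            # gather B[i, a] one element at a time instead of fancy indexing
            for j in range(a.shape[0]):
                b[j] = B[i, a[j]]
            P[:, i] = b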
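
If you want to convince yourself that an explicit loop gives the same result as the NumPy indexing B[i, a], a quick check along these lines may help. This is only a sketch: gather_rows and reference are hypothetical names, the temporary b is dropped in favour of writing straight into P, and the try/except fallback to a no-op decorator is an assumption added so the check also runs where Numba isn't installed.

```python
import numpy as np

try:
    import numba as nb
    njit = nb.njit
except ImportError:
    # Assumption for illustration: without Numba, run the function as plain Python.
    def njit(func):
        return func


@njit
def gather_rows(P, B, D):
    # Scalar indexing only: P[j, i] = B[i, D[i, j]]
    for i in range(D.shape[0]):
        for j in range(D.shape[1]):
            P[j, i] = B[i, D[i, j]]


def reference(B, D):
    # Plain NumPy version using fancy indexing (not compiled by Numba).
    P = np.zeros((D.shape[1], D.shape[0]))
    for i in range(D.shape[0]):
        P[:, i] = B[i, D[i, :]]
    return P


rng = np.random.default_rng(0)
B = rng.random((5, 5)) * 100
D = rng.integers(0, 5, size=(5, 5), dtype=np.int32)

P = np.zeros((5, 5))
gather_rows(P, B, D)
print(np.allclose(P, reference(B, D)))
```

Writing into P directly avoids allocating b at all, which is usually fine inside a compiled loop, but whether it is faster for your sizes is something you'd have to measure.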