python numpy numba

Unsupported array index type when using numba


I would like to use Numba to speed up my code (see the MWE below). However, I get NumbaTypeError: unsupported array index type. What is the problem, and how can I fix it?

import numpy as np
import numba as nb

a = np.array([4, 5, 6, 7, 8, 9], dtype=np.int16)
b = np.array([ 2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19], dtype=np.int16)
c = np.zeros((14, 20, 2), dtype=np.int16)

@nb.njit(fastmath=True)
def printNumbers(a, b, c):
    # Two index arrays (one of them 2D) in a single indexing step: this line raises the error
    d = c[a.reshape((a.size, 1)), b, :]

    print(d)

printNumbers(a, b, c)
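Outside Numba (without @nb.njit) the same expression is valid NumPy: the (6, 1) index array broadcasts against the length-18 array b, so d[i, j, :] is c[a[i], b[j], :]. A small plain-NumPy check of that intended meaning, assuming a c filled with distinct values instead of zeros so that the comparison is not trivially true:

```python
import numpy as np

a = np.array([4, 5, 6, 7, 8, 9], dtype=np.int16)
b = np.arange(2, 20, dtype=np.int16)  # same values 2..19 as b in the MWE
# Distinct values (assumption for illustration; the MWE uses zeros)
c = np.arange(14 * 20 * 2, dtype=np.int16).reshape((14, 20, 2))

# Two advanced indices: a (6, 1) column and b (18,) broadcast to (6, 18)
d = c[a.reshape((a.size, 1)), b, :]
print(d.shape)  # (6, 18, 2)

# Each entry picks row a[i] and column b[j]
assert all(
    np.array_equal(d[i, j], c[a[i], b[j]])
    for i in range(a.size)
    for j in range(b.size)
)
# np.ix_ builds the same open mesh of indices (plain NumPy)
assert np.array_equal(d, c[np.ix_(a, b)])
```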

Solution

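  • Although Numba does support the reshape function, the reshape is not the real problem. Numba implements only a subset of NumPy's advanced indexing: a single indexing operation may use only one advanced (array) index, and that index must be one-dimensional. c[a.reshape((a.size, 1)), b, :] uses two index arrays, one of them 2D, so Numba rejects it. I therefore removed the reshape and split the indexing into two steps, each with a single 1D index array: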

    import numpy as np
    import numba as nb
    
    a = np.array([4, 5, 6, 7, 8, 9], dtype=np.int16)
    b = np.array([ 2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19], dtype=np.int16)
    c = np.zeros((14, 20, 2), dtype=np.int16)
    
    @nb.njit(fastmath=True)
    def printNumbers(a, b, c):
        # One 1D index array per step: c[a] -> shape (6, 20, 2), then [:, b] -> (6, 18, 2)
        d = c[a][:, b]
    
        print(d)
    
    printNumbers(a, b, c)