
Numba Exception: Cannot determine Numba type of <class 'type'>


I want to convert a function to Numba for performance reasons; a minimal working example (MWE) is below. Without the @njit decorator the code works, but with @njit I get an exception when the function is called. The exception is most likely caused by the dtype=object used to define result_arr, but I also tried dtype=float64 and got a similar exception.
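Switching to float64 cannot work even without Numba: a float64 array stores one scalar per element, so a 2-D block of random values has nowhere to go. The plain-NumPy check below (no Numba involved, variable names are illustrative only) shows the difference between the two dtypes. As far as I know, Numba's nopython mode does not support object arrays at all, which fits the `Cannot determine Numba type of <class 'type'>` message: Numba cannot type the Python `object` passed as dtype.

```python
import numpy as np

# A float64 array holds exactly one scalar per element...
float_arr = np.empty(3, dtype=np.float64)
try:
    float_arr[0] = np.ones((2, 2))  # a 2x2 block cannot fit in one float slot
    float_failed = False
except ValueError as exc:
    float_failed = True
    print("float64:", exc)

# ...whereas an object array stores references to arbitrary Python objects,
# which is why the undecorated version of the function works.
obj_arr = np.empty(3, dtype=object)
obj_arr[0] = np.ones((2, 2))
print("object:", obj_arr[0].shape)
```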

import numpy as np
from numba import njit

######-----------Required NUMBA function----------###
# @njit  # <-- uncomment to reproduce the exception; without it the code works
def required_numba_function():
    nRows = 151
    nCols = 151
    nFrames = 24
    result_arr = np.empty(nRows * nCols * nFrames, dtype=object)

    for frame in range(nFrames):
        for row in range(nRows):
            for col in range(nCols):        
                size_rows = np.random.randint(8, 15) 
                size_cols = np.random.randint(2, 6)            
                args = np.random.normal(3, 2.5, size=(size_rows, size_cols)) # size is random
                flat_idx = frame * (nRows * nCols) + (row * nCols + col)
                result_arr[flat_idx] = args

    return result_arr

######------------------main()-------##################
if __name__ == "__main__":
    required_numba_function()

    print() 
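The flat_idx formula is the row-major (C-order) offset of (frame, row, col) in an array of shape (nFrames, nRows, nCols), so result_arr.reshape(nFrames, nRows, nCols)[frame, row, col] returns the same element as result_arr[flat_idx]. A small NumPy check against np.ravel_multi_index (the helper flat_index is hypothetical, just for illustration):

```python
import numpy as np

n_rows, n_cols, n_frames = 151, 151, 24

def flat_index(frame, row, col):
    # Same formula as the flat_idx line in required_numba_function.
    return frame * (n_rows * n_cols) + (row * n_cols + col)

# np.ravel_multi_index computes the C-order (row-major) offset for a given shape.
shape = (n_frames, n_rows, n_cols)
for frame, row, col in [(0, 0, 0), (0, 0, 1), (0, 1, 0), (1, 0, 0), (23, 150, 150)]:
    assert flat_index(frame, row, col) == np.ravel_multi_index((frame, row, col), shape)

print(flat_index(23, 150, 150))  # last cell: 24 * 151 * 151 - 1
```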

How can I resolve the Numba exception?


Solution

  • Since you said a list of arrays is fine, you can replace the empty dtype=object array assigned to result_arr with an empty list and append each new array to it on every iteration. Numba supports appending arrays to a list in nopython mode:

    import numpy as np
    import numba as nb

    @nb.njit
    def required_numba_function2():
        # For testing only: Numba keeps its own random state, separate from
        # NumPy's, so the seed has to be set inside the jitted function.
        np.random.seed(0)
        nRows = 151
        nCols = 151
        nFrames = 24
        result_arr = []
    
        for frame in range(nFrames):
            for row in range(nRows):
                for col in range(nCols):        
                    size_rows = np.random.randint(8, 15) 
                    size_cols = np.random.randint(2, 6)            
                    args = np.random.normal(3, 2.5, size=(size_rows, size_cols)) # size is random
                    result_arr.append(args)
    
        return result_arr
    

    Test (against the original, undecorated function):

    np.random.seed(0)
    result = required_numba_function()
    result2 = required_numba_function2()
    for i, j in zip(result, result2):
        assert np.allclose(i, j)
    

    Timings (IPython %timeit):

    %timeit required_numba_function()
    %timeit required_numba_function2()
    
    2.08 s ± 37.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    606 ms ± 3.43 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
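The list-based function returns a list rather than the object array the original returned. If downstream code expects the original layout, a minimal sketch like the one below could convert the list after the jitted call, in plain Python. The helper name to_object_array is hypothetical, and it assumes the list is in the same (frame, row, col) order as the loops above.

```python
import numpy as np

def to_object_array(arrays, n_frames, n_rows, n_cols):
    """Pack a list of 2-D arrays (in frame, row, col order) into an object array.

    Runs in plain Python, outside any @njit function, because Numba's
    nopython mode cannot create object arrays.
    """
    if len(arrays) != n_frames * n_rows * n_cols:
        raise ValueError("list length does not match n_frames * n_rows * n_cols")
    out = np.empty(len(arrays), dtype=object)
    # Assign element by element: np.array(arrays, dtype=object) would build a
    # higher-dimensional array instead if all blocks happened to share a shape.
    for i, block in enumerate(arrays):
        out[i] = block
    return out.reshape(n_frames, n_rows, n_cols)
```

Usage would be along the lines of `arr = to_object_array(required_numba_function2(), 24, 151, 151)`, after which `arr[frame, row, col]` is the block for that cell. The conversion loop runs in the interpreter, so it adds some overhead on top of the jitted part.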