I want to convert a function to Numba for performance reasons. My MWE is below. If I remove the @njit decorator, the code works, but with @njit I get a runtime exception. The exception most likely comes from using dtype=object to define result_arr, but I also tried dtype=float64 and got a similar exception.
import numpy as np
from numba import njit
from timeit import timeit

######-----------Required NUMBA function----------###
#@njit #<----without this, the code works
def required_numba_function():
    nRows = 151
    nCols = 151
    nFrames = 24
    result_arr = np.empty((151 * 151 * 24), dtype=object)
    for frame in range(nFrames):
        for row in range(nRows):
            for col in range(nCols):
                size_rows = np.random.randint(8, 15)
                size_cols = np.random.randint(2, 6)
                args = np.random.normal(3, 2.5, size=(size_rows, size_cols)) # size is random
                flat_idx = frame * (nRows * nCols) + (row * nCols + col)
                result_arr[flat_idx] = args
    return result_arr

######------------------main()-------##################
if __name__ == "__main__":
    required_numba_function()
    print()
How can I resolve the Numba exception?
Since you say that a list of arrays is fine, you can replace result_arr, currently an empty array of dtype=object, with an empty list that you append to at each iteration. This is compatible with Numba:
import numba as nb

@nb.njit
def required_numba_function2():
    np.random.seed(0) # Just for testing; the seed has to be set inside the function for Numba to be aware of it
    nRows = 151
    nCols = 151
    nFrames = 24
    result_arr = []
    for frame in range(nFrames):
        for row in range(nRows):
            for col in range(nCols):
                size_rows = np.random.randint(8, 15)
                size_cols = np.random.randint(2, 6)
                args = np.random.normal(3, 2.5, size=(size_rows, size_cols)) # size is random
                result_arr.append(args)
    return result_arr
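The list version drops the explicit flat_idx, so it may be worth checking that nothing is lost. Because the loops append in frame, row, col order, the position of each array in the list should equal the flat_idx formula from the question. The sketch below checks this in plain Python; flat_index and append_order_matches_flat_index are hypothetical helper names, not part of the original code.

```python
def flat_index(frame, row, col, n_rows, n_cols):
    # Same formula as flat_idx in the question.
    return frame * (n_rows * n_cols) + (row * n_cols + col)


def append_order_matches_flat_index(n_frames, n_rows, n_cols):
    # Walk the loops in the same order as the answer's function and
    # compare the running append position with the flat index.
    position = 0
    for frame in range(n_frames):
        for row in range(n_rows):
            for col in range(n_cols):
                if position != flat_index(frame, row, col, n_rows, n_cols):
                    return False
                position += 1
    return True


# Dimensions taken from the question: 24 frames of 151 x 151.
print(append_order_matches_flat_index(24, 151, 151))
```

Under that assumption, the array for a given (frame, row, col) can be read back as result2[flat_index(frame, row, col, 151, 151)].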
Test:
np.random.seed(0)
result = required_numba_function()
result2 = required_numba_function2()
for i, j in zip(result, result2):
    assert np.allclose(i, j)
Time:
%timeit required_numba_function()
%timeit required_numba_function2()
2.08 s ± 37.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
606 ms ± 3.43 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
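%timeit is only available in IPython/Jupyter. Outside of it, a rough equivalent can be put together with the standard timeit module. The sketch below calls the function once before timing, since an @njit function without an explicit signature is compiled on its first call; time_after_warmup and stand_in are hypothetical names used for illustration, and stand_in is a plain-Python placeholder for the real functions.

```python
from timeit import timeit


def time_after_warmup(func, number=3):
    # The first call triggers JIT compilation for an @njit function,
    # so it is kept out of the measurement.
    func()
    # Average wall-clock seconds per call over `number` calls.
    return timeit(func, number=number) / number


def stand_in():
    # Placeholder workload; swap in required_numba_function2 to time it.
    return sum(range(100_000))


print(f"{time_after_warmup(stand_in):.6f} s per call")
```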