I'm converting some odd code over to be Numba compatible using parallel=True. It has a problematic array assignment that I can't quite figure out how to rewrite in a way Numba can handle. I tried to decode what the error means, but I got quite lost. The only clear thing is that it does not like the line: Averaging_price_3D[leg, :, expired_loc] = last_non_expired_values.T
The error is pretty long, included for reference here:
TypingError: No implementation of function Function(<built-in function setitem>) found for signature:
setitem(array(float64, 3d, C), Tuple(int64, slice<a:b>, array(int64, 1d, C)), array(float64, 2d, F))
There are 16 candidate implementations:
- Of which 14 did not match due to:
Overload of function 'setitem': File: <numerous>: Line N/A.
With argument(s): '(array(float64, 3d, C), Tuple(int64, slice<a:b>, array(int64, 1d, C)), array(float64, 2d, F))':
No match.
- Of which 2 did not match due to:
Overload in function 'SetItemBuffer.generic': File: numba\core\typing\arraydecl.py: Line 176.
With argument(s): '(array(float64, 3d, C), Tuple(int64, slice<a:b>, array(int64, 1d, C)), array(float64, 2d, F))':
Rejected as the implementation raised a specific error:
NumbaNotImplementedError: only one advanced index supported
And here is a short code segment to reproduce the error:
import numpy as np
import numba as nb

@nb.jit(nopython=True, parallel=True, nogil=True)
def main(Averaging_price_3D, expired_loc, last_non_expired_values):
    for leg in range(Averaging_price_3D.shape[0]):
        # line below causes the numba error:
        Averaging_price_3D[leg, :, expired_loc] = last_non_expired_values.T
    return Averaging_price_3D

if __name__ == "__main__":
    Averaging_price_3D = np.random.rand(2, 8192, 11) * 100  # shape (2,8192,11) 3D array float64
    expired_loc = np.arange(4, 10).astype(np.int64)  # shape (6,) 1D array int64
    last_non_expired_values = Averaging_price_3D[1, :, 0:expired_loc.shape[0]].copy()  # shape (8192,6) 2D array float64
    result = main(Averaging_price_3D, expired_loc, last_non_expired_values)
The best I can interpret this error is that "Numba doesn't know how to set values in a 3D array using array indexing with values from a 2D array." But I searched online quite a bit and couldn't find another way to accomplish the same thing without Numba rejecting it.
In other cases like this I resorted to flattening the arrays with a .reshape(-1) before indexing, but I'm having issues with figuring out how to do that in this specific case (that was easy with a 3D array indexed with another 3D array, as they both would flatten in the same order)... Any help is appreciated!
Well, interestingly enough, I looked at the indexes passed to the 3D array (since the error said "only one advanced index supported", I chose to examine my indexing):
3Darray[int, :, 1Darray]
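For reference, plain NumPy treats this pattern as advanced indexing where the int and the 1D array are separated by a slice, so the indexed dimension moves to the front of the result. That is why the original assignment needed the .T on the right-hand side. A minimal NumPy-only sketch of this, assuming a small made-up array (the name arr and the shape (2, 5, 11) are illustrative, not the real data):

```python
import numpy as np

# Small stand-in for Averaging_price_3D; the shape (2, 5, 11) is illustrative only.
arr = np.arange(2 * 5 * 11, dtype=np.float64).reshape(2, 5, 11)
expired_loc = np.arange(4, 10)  # 6 column positions, as in the question

# int + slice + 1D array: the int and the array are separated by the slice,
# so the indexed dimension (6) comes first, then the sliced one (5).
selected = arr[1, :, expired_loc]
selected_shape = selected.shape

# Taking the leg first leaves one array index on a 2D array,
# and the result keeps the natural (rows, columns) order.
natural = arr[1][:, expired_loc]
natural_shape = natural.shape

print(selected_shape, natural_shape)
```

This only shows the NumPy semantics; it says nothing about which of these forms Numba will accept.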
Seeing Numba is quite picky, I tried rewriting it a little so that a 1D array wasn't used as an index (apparently, this is an "advanced index", so use an int index instead). Reading Numba errors and solutions, they tend to add loops, so I tried that here. Instead of passing a 1D array as an index, I looped over the elements of the 1D array:
import numpy as np
import numba as nb

@nb.jit(cache=True, nopython=True, parallel=True, nogil=True)
def main(Averaging_price_3D, expired_loc, last_non_expired_values):
    for leg in nb.prange(Averaging_price_3D.shape[0]):
        # change the indexing for numba to get rid of a 1D array index
        for i in nb.prange(expired_loc.shape[0]):
            # now we assign values 3Darray[int,:,int] = 1Darray
            Averaging_price_3D[leg, :, expired_loc[i]] = last_non_expired_values[:, i].T
    return Averaging_price_3D

if __name__ == "__main__":
    Averaging_price_3D = np.random.rand(2, 8192, 11) * 100  # shape (2,8192,11) 3D array float64
    expired_loc = np.arange(4, 10).astype(np.int64)  # shape (6,) 1D array int64
    # .copy() as in the question: without it this is a view into Averaging_price_3D,
    # and the writes inside main() would change the values being read
    last_non_expired_values = Averaging_price_3D[1, :, 0:expired_loc.shape[0]].copy()  # shape (8192,6) 2D array float64
    result = main(Averaging_price_3D, expired_loc, last_non_expired_values)
Now it works no problem at all. So it appears to me that if you want to access elements from a 3D array with Numba, you should do it with either ints or :. It appears to not like 1D array indexing here, so replace it with a loop and it should run in parallel.
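To gain some confidence that the loop rewrite assigns the same values as the original line, the two can be compared in plain NumPy with no Numba involved. This is only a sketch: the helper names assign_fancy and assign_loop are made up for illustration, and the array is smaller than the real one.

```python
import numpy as np

def assign_fancy(arr, expired_loc, values):
    # Original formulation: valid in NumPy, rejected by Numba's typing.
    out = arr.copy()
    for leg in range(out.shape[0]):
        out[leg, :, expired_loc] = values.T
    return out

def assign_loop(arr, expired_loc, values):
    # Rewritten formulation: only ints and slices are used as indices.
    out = arr.copy()
    for leg in range(out.shape[0]):
        for i in range(expired_loc.shape[0]):
            out[leg, :, expired_loc[i]] = values[:, i]
    return out

rng = np.random.default_rng(0)
arr = rng.random((2, 64, 11)) * 100
expired_loc = np.arange(4, 10).astype(np.int64)
values = arr[1, :, 0:expired_loc.shape[0]].copy()  # copy, so it does not alias arr

fancy_result = assign_fancy(arr, expired_loc, values)
loop_result = assign_loop(arr, expired_loc, values)
same = np.array_equal(fancy_result, loop_result)
print(same)
```

Two side notes on the loop version: values[:, i] is already 1D, so the .T on it is a no-op and can be dropped. And as far as I understand the Numba documentation, when prange loops are nested only the outermost one runs in parallel, while the inner one behaves like a regular range.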