I'm trying to accelerate a DP algorithm in Python, and Numba seemed like an appropriate candidate.
I subtract a 2D array from a 1D array, which broadcasts to a 3D array. I then call .argmin()
along the third dimension to obtain a 2D array. This works fine with NumPy, but not with Numba.
Toy code reproducing the issue:
from numba import jit
import numpy as np

inflow = np.arange(1,0,-0.01) # Dim [T]
actions = np.arange(0,1,0.05) # Dim [M]
start_lvl = np.random.rand(500).reshape(-1,1)*49 # Dim [Nx1]
disc_lvl = np.arange(0,1000) # Dim [O]

@jit(nopython=True)
def my_func(disc_lvl, actions, start_lvl, inflow):
    for i in range(0,100):
        # Calculate new level at time i
        new_lvl = start_lvl + inflow[i] + actions # Dim [N x M]
        # For each new_level element, find closest discretized level
        diff = (disc_lvl-new_lvl[:,:,np.newaxis]) # Dim [N x M x O]
        idx_lvl = abs(diff).argmin(axis=2) # Dim [N x M]
    return True

# function works fine without numba
success = my_func(disc_lvl, actions, start_lvl, inflow)
Why doesn't the code above run? It does once @jit(nopython=True) is removed.
Is there a workaround to make this calculation work with Numba?
I've tried variants with np.repeat and np.expand_dims, as well as explicitly declaring the input types of the jitted function, without success.
There are a few things you need to change to make it work:
- arr[:, :, None]: to Numba this looks like a getitem call, so prefer using reshape.
- Use np.abs instead of the built-in abs.
- argmin with the axis keyword argument is not implemented. Prefer using loops, which Numba is designed to optimize.

With all this fixed you can run the jitted function:
from numba import jit
import numpy as np

inflow = np.arange(1,0,-0.01) # Dim [T]
actions = np.arange(0,1,0.05) # Dim [M]
start_lvl = np.random.rand(500).reshape(-1,1)*49 # Dim [Nx1]
disc_lvl = np.arange(0,1000) # Dim [O]

@jit(nopython=True)
def my_func(disc_lvl, actions, start_lvl, inflow):
    for i in range(0,100):
        # Calculate new level at time i
        new_lvl = start_lvl + inflow[i] + actions # Dim [N x M]
        # For each new_level element, find closest discretized level
        new_lvl_3d = new_lvl.reshape(*new_lvl.shape, 1) # Dim [N x M x 1]
        diff = np.abs(disc_lvl - new_lvl_3d) # Dim [N x M x O]
        # Integer dtype, since the array stores indices into disc_lvl
        idx_lvl = np.empty(new_lvl.shape, dtype=np.int64) # Dim [N x M]
        # Use n and m here so the time index i is not shadowed
        for n in range(diff.shape[0]):
            for m in range(diff.shape[1]):
                idx_lvl[n, m] = diff[n, m, :].argmin()
    return True

# same call as before, now with the jitted function
success = my_func(disc_lvl, actions, start_lvl, inflow)
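
Since loops are what Numba optimizes, one possible variation is to skip the intermediate [N x M x O] array altogether and track the running minimum directly. The following is only a sketch of that idea, not part of the code above: the helper name nearest_level_idx is hypothetical, and the try/except fallback is an assumption added so the snippet also runs as plain Python when Numba is not installed.

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # Assumption: fall back to plain Python so the sketch still runs without Numba
    def njit(func):
        return func


@njit
def nearest_level_idx(new_lvl, disc_lvl):
    """For each element of new_lvl [N x M], return the index of the closest value in disc_lvl [O]."""
    n_rows, n_cols = new_lvl.shape
    idx = np.empty((n_rows, n_cols), dtype=np.int64)
    for n in range(n_rows):
        for m in range(n_cols):
            best_k = 0
            best_d = abs(disc_lvl[0] - new_lvl[n, m])
            for k in range(1, disc_lvl.shape[0]):
                d = abs(disc_lvl[k] - new_lvl[n, m])
                # Strict comparison keeps the first minimum, matching argmin's tie-breaking
                if d < best_d:
                    best_d = d
                    best_k = k
            idx[n, m] = best_k
    return idx


# Small check against the plain NumPy formulation from the question
rng = np.random.default_rng(0)
new_lvl = rng.random((5, 4)) * 49      # Dim [N x M]
disc_lvl = np.arange(0, 1000)          # Dim [O]
expected = np.abs(disc_lvl - new_lvl[:, :, np.newaxis]).argmin(axis=2)
result = nearest_level_idx(new_lvl, disc_lvl)
```

This trades the 3D temporary for a third loop, which may reduce memory use for large N, M and O. Whether it is actually faster depends on the array sizes, so it is worth timing both versions.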