Tags: python, numpy, multidimensional-array, subtraction, numba

Numba and multidimensional additions - not working with numpy.newaxis?


I'm trying to accelerate a DP algorithm in Python, and Numba seemed like an appropriate candidate.

I'm subtracting a 2D array (with an added trailing axis) from a 1D array, which broadcasts to a 3D array. I then call .argmin() along the third dimension to obtain a 2D array. This works fine with NumPy, but fails with Numba.
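For reference, here is the same broadcasting pattern in plain NumPy on a tiny made-up input, so the shapes are easy to follow. The values are illustrative only and are not taken from the real problem:

```python
import numpy as np

# Illustrative values only (not the real data)
new_lvl = np.array([[0.2, 1.7],
                    [2.4, 3.1],
                    [0.9, 2.6]])          # Dim [N x M] = [3 x 2]
disc_lvl = np.arange(0, 4)                # Dim [O] = [4]

# The trailing axis makes new_lvl [N x M x 1], which broadcasts against [O]
diff = disc_lvl - new_lvl[:, :, np.newaxis]   # Dim [N x M x O]

# Index of the closest discretized level for every element of new_lvl
idx_lvl = np.abs(diff).argmin(axis=2)         # Dim [N x M]
```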

Toy code reproducing the issue:

from numba import jit
import numpy as np

inflow      = np.arange(1,0,-0.01)                  # Dim [T]
actions     = np.arange(0,1,0.05)                   # Dim [M]
start_lvl   = np.random.rand(500).reshape(-1,1)*49  # Dim [Nx1]
disc_lvl    = np.arange(0,1000)                     # Dim [O]

@jit(nopython=True)
def my_func(disc_lvl, actions, start_lvl, inflow):
    for i in range(0,100):
        # Calculate new level at time i
        new_lvl = start_lvl + inflow[i] + actions       # Dim [N x M]

        # For each new_level element, find closest discretized level
        diff    = (disc_lvl-new_lvl[:,:,np.newaxis])    # Dim [N x M x O]
        idx_lvl = abs(diff).argmin(axis=2)              # Dim [N x M]

        return True

# function works fine without numba
success = my_func(disc_lvl, actions, start_lvl, inflow)

Why doesn't the code above run? It does once @jit(nopython=True) is removed. Is there a workaround to make this calculation work with Numba?

I've tried variants with np.repeat and np.expand_dims, as well as explicitly declaring the input types of the jitted function, without success.


Solution

  • There are a few things you need to change to make it work:

    1. Adding a dimension with arr[:, :, None]: to Numba this looks like an ordinary getitem call, which it appears to reject with None as an index, so use reshape instead.
    2. Use np.abs instead of the built-in abs.
    3. argmin with the axis keyword argument is not implemented in Numba. Use explicit loops instead, which Numba is designed to optimize.

    With these changes, the jitted function runs:

    from numba import jit
    import numpy as np
    
    inflow = np.arange(1,0,-0.01)  # Dim [T]
    actions = np.arange(0,1,0.05)  # Dim [M]
    start_lvl = np.random.rand(500).reshape(-1,1)*49  # Dim [Nx1]
    disc_lvl = np.arange(0,1000)  # Dim [O]
    
    @jit(nopython=True)
    def my_func(disc_lvl, actions, start_lvl, inflow):
        for i in range(0,100):
            # Calculate new level at time i
            new_lvl = start_lvl + inflow[i] + actions  # Dim [N x M]
    
            # For each new_level element, find closest discretized level
            new_lvl_3d = new_lvl.reshape(*new_lvl.shape, 1)
            diff = np.abs(disc_lvl - new_lvl_3d)  # Dim [N x M x O]
    
            # Integer array for the indices; n and m avoid shadowing the time index i
            idx_lvl = np.empty(new_lvl.shape, dtype=np.int64)
            for n in range(diff.shape[0]):
                for m in range(diff.shape[1]):
                    idx_lvl[n, m] = diff[n, m, :].argmin()
    
            return True
    
    # the jitted function now compiles and runs
    success = my_func(disc_lvl, actions, start_lvl, inflow)
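
Since the advice above is to prefer loops, the N x M x O temporary can be skipped altogether. The following is only a sketch of that idea, not part of the original answer: `closest_level_idx` is a hypothetical helper name, the input sizes are made up and kept small, and the `njit` fallback is an assumption added so the snippet still runs (slowly) where Numba is not installed.

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # Assumption: fall back to plain Python so the sketch runs without Numba
    def njit(func):
        return func


@njit
def closest_level_idx(disc_lvl, new_lvl):
    """For each element of new_lvl, return the index of the nearest disc_lvl entry."""
    n_rows, n_cols = new_lvl.shape
    idx_lvl = np.empty((n_rows, n_cols), dtype=np.int64)
    for n in range(n_rows):
        for m in range(n_cols):
            best_k = 0
            best_dist = np.abs(disc_lvl[0] - new_lvl[n, m])
            for k in range(1, disc_lvl.shape[0]):
                dist = np.abs(disc_lvl[k] - new_lvl[n, m])
                # Strict '<' keeps the first minimum, matching argmin's tie rule
                if dist < best_dist:
                    best_dist = dist
                    best_k = k
            idx_lvl[n, m] = best_k
    return idx_lvl


# Small illustrative inputs (not the sizes from the question)
np.random.seed(0)
disc_lvl = np.arange(0, 10)                                  # Dim [O]
new_lvl = np.random.rand(5, 1) * 9 + np.arange(0, 1, 0.25)   # Dim [N x M]
idx_lvl = closest_level_idx(disc_lvl, new_lvl)               # Dim [N x M]
```

This trades the broadcasted subtraction for three nested loops, so no 3D array is ever allocated. Whether it is actually faster for the real problem sizes would need to be measured.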