
numba.jit can’t compile np.roll


I’m trying to compile the "foo" function using jit:

import numpy as np
from numba import jit

dy = 5
@jit
def foo(grid):
    return np.sum([np.roll(np.roll(grid, y, axis = 1), x, axis = 0)
                   for x in (-1, 0, 1) for y in (-1, 0, 1) if x or y], axis=0)


ex_grid = np.random.rand(5,5)>0.5
result = foo(ex_grid)

And I get the following error:

Compilation is falling back to object mode WITH looplifting enabled because Function "foo" failed type inference due to: Invalid use of Function(<function roll at 0x00000161E45C7D90>) with argument(s) of type(s): (array(bool, 2d, C), Literal[int](5), axis=Literal[int](1))
 * parameterized
In definition 0:
    TypeError: np_roll() got an unexpected keyword argument 'axis'

The function works, but the compilation fails.

How can I fix this error? Is np.roll compatible with numba, and if not, is there an alternative?


Solution

  • If you check the docs, you'll see that numba supports only the first two arguments of np.roll. Since you cannot specify an axis, it can only roll the flattened array:

    numpy.roll() (only the 2 first arguments; second argument shift must be an integer)
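To make the difference concrete, here is a small plain-NumPy sketch (no numba involved) comparing a roll without `axis`, which is the only form the docs list as supported, against a per-row roll with `axis=1`. The array `a` is a made-up example, not the grid from the question.

```python
import numpy as np

# Hypothetical 2x3 example array, chosen so the shifts are easy to follow.
a = np.arange(6).reshape(2, 3)   # [[0, 1, 2], [3, 4, 5]]

# Without axis: the array is flattened, rolled, then reshaped back,
# so elements wrap from one row into the next.
flat_roll = np.roll(a, 1)        # [[5, 0, 1], [2, 3, 4]]

# With axis=1: each row is rolled independently.
row_roll = np.roll(a, 1, axis=1) # [[2, 0, 1], [5, 3, 4]]

print(flat_roll)
print(row_roll)
```

The two results differ, which is why dropping the `axis` argument is not a substitute for the original call.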

    Note however that it does not really make sense to use numba here, since you're performing a single vectorized operation, which will already run very fast. Numba would only make sense if you had to loop over the array to apply some logic.

    So the only way to roll the rows of your array with np.roll inside a numba-compiled function is to loop over them:

    from numba import njit

    @njit
    def foo(a, dy):
        # Roll each row separately, since numba's np.roll takes no axis argument
        out = np.empty(a.shape, np.int32)
        for i in range(a.shape[0]):
            out[i] = np.roll(a[i], dy)
        return out
    
    np.allclose(foo(ex_grid, 3).astype(bool), np.roll(ex_grid, 3, axis=1))
    # True
    

    As mentioned, though, this is much slower than simply calling np.roll with axis=1, which is already vectorized, with all looping done at the C level:

    ex_grid = np.random.rand(5000,5000)>0.5
    
    %timeit foo(ex_grid, 3)
    # 111 ms ± 820 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
    
    %timeit np.roll(ex_grid, 1, axis=1)
    # 13.8 ms ± 127 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
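If the neighbor sum from the question really does have to live inside jitted code, one option is to skip np.roll entirely and write the wraparound indexing as explicit loops. The following is only a sketch under a few assumptions: the name `neighbor_sum` is hypothetical, the grid is assumed to be a 2-D boolean array, and the `try/except` fallback exists only so that the snippet still runs where numba is not installed. No timing claim is made for it.

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # Fallback (assumption for this sketch): run as plain Python without numba.
    def njit(func):
        return func


@njit
def neighbor_sum(grid):
    """Count the True cells among the 8 neighbors of each cell, wrapping at the edges."""
    rows, cols = grid.shape
    out = np.zeros((rows, cols), dtype=np.int64)
    for i in range(rows):
        for j in range(cols):
            total = 0
            for di in range(-1, 2):
                for dj in range(-1, 2):
                    if di == 0 and dj == 0:
                        continue  # skip the cell itself
                    # Adding rows/cols before the modulo keeps the index non-negative.
                    ni = (i + di + rows) % rows
                    nj = (j + dj + cols) % cols
                    if grid[ni, nj]:
                        total += 1
            out[i, j] = total
    return out


ex_grid = np.random.rand(5, 5) > 0.5

# Reference result: the np.roll expression from the question, run in plain NumPy.
expected = np.sum([np.roll(np.roll(ex_grid, y, axis=1), x, axis=0)
                   for x in (-1, 0, 1) for y in (-1, 0, 1) if x or y], axis=0)

print(np.array_equal(neighbor_sum(ex_grid), expected))
```

The modulo arithmetic reproduces the periodic boundary that np.roll gives implicitly, so the result can be checked against the original expression as done in the last lines.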