python numpy numba

numba np.diff with axis=0


In numba, axis=0 is an accepted parameter for np.sum(), but not for np.diff(). Why does this happen? I'm working with 2D arrays, so I need to specify the axis.

import numpy as np
from numba import jit

@jit(nopython=True)
def jitsum(y):
    return np.sum(y, axis=0)

@jit(nopython=True)
def jitdiff(y):  # this one raises an error when called
    return np.diff(y, axis=0)

Error: np_diff_impl() got an unexpected keyword argument 'axis'

A workaround for 2D arrays is to transpose, take the difference along the last axis, and transpose back:

@jit(nopython=True)
def jitdiff(y):
    return np.diff(y.T).T
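To check that the transpose trick matches np.diff(y, axis=0), here is a small sketch. It uses numba's njit, which is shorthand for jit(nopython=True). The function name diff_rows_via_transpose is hypothetical, and the try/except fallback that runs the function as plain Python when numba is not installed is an assumption added only so the snippet runs anywhere.

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # Assumption: without numba, fall back to plain Python so the sketch
    # still runs; the comparison is then NumPy against NumPy.
    def njit(func):
        return func


@njit
def diff_rows_via_transpose(y):
    # np.diff without axis differences along the last axis, so transpose,
    # difference, and transpose back to difference along axis 0.
    return np.diff(y.T).T


y = np.array([[1.0, 2.0, 4.0],
              [3.0, 7.0, 5.0],
              [10.0, 8.0, 6.0]])

result = diff_rows_via_transpose(y)
print(np.array_equal(result, np.diff(y, axis=0)))  # True
```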

Solution

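  • numba's implementation of np.diff accepts only its first two arguments (a and n), which is why axis triggers the error above. You can write the difference out by hand instead: for a 2D array a, np.diff(a, n=1, axis=1) is just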

    a[:, 1:] - a[:, :-1]
    

    For axis=0:

    a[1:, :] - a[:-1, :]
    

    Both expressions use only basic slicing and elementwise subtraction, which numba supports in nopython mode, so they should compile without problems.
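The two slice expressions generalize to a small helper that takes n and axis the way np.diff does. This is a sketch assuming 2D input and axis in {0, 1}; the name diff2d and its signature are hypothetical (not part of NumPy or numba), and the numba import fallback exists only so it runs without numba installed.

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # Assumption: fall back to plain Python if numba is not installed.
    def njit(func):
        return func


@njit
def diff2d(a, n=1, axis=0):
    # Hypothetical helper: n-th order difference of a 2D array along
    # axis 0 or 1, meant to mirror np.diff(a, n=n, axis=axis).
    if n < 0:
        raise ValueError("order must be non-negative")
    if axis != 0 and axis != 1:
        raise ValueError("axis must be 0 or 1 for a 2D array")
    out = a.copy()  # C-contiguous working copy
    for _ in range(n):
        # Repeat the slice-and-subtract step once per order.
        if axis == 0:
            out = out[1:, :] - out[:-1, :]
        else:
            out = out[:, 1:] - out[:, :-1]
    return out


a = np.array([[1.0, 2.0, 4.0],
              [3.0, 7.0, 5.0],
              [10.0, 8.0, 6.0],
              [0.0, 1.0, 2.0]])

print(np.array_equal(diff2d(a), np.diff(a, axis=0)))                      # True
print(np.array_equal(diff2d(a, n=2, axis=1), np.diff(a, n=2, axis=1)))    # True
```

Starting from a.copy() keeps out a C-contiguous array throughout, which keeps numba's type inference simple, and it avoids returning an alias of the input when n=0. Repeating the first-order difference n times is also how NumPy itself computes higher-order differences.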