
Calling an njit function inside a Python Numba jitclass fails


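import numpy as np
from numba import njit, float64
from numba.experimental import jitclass

@njit
def cumutrapz(x:np.array, y:np.array):
    # cumulative trapezoidal integral of y over x, starting at 0
    return np.append(0, [
        np.trapz(y=y[i-2:i], x=x[i-2:i]) for i in range(2, len(x) + 1)]).cumsum()
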
@jitclass([
    ('a', float64[:]),
    ('b', float64[:]),    
    ('c', float64[:]),    
])
class Testaroo(object):
    def __init__(self, a, b):
        self.a = a
        self.b = b
        self.c = np.zeros(len(self.a), dtype=np.float64)
        
    def set_c(self):
        self.c = cumutrapz(self.a, self.b)
        
testaroo = Testaroo(  
    np.arange(50, dtype=np.float64), np.sin(np.arange(50, dtype=np.float64)))
testaroo.set_c()

The above fails, but the following two very similar examples work:

cumutrapz(np.arange(50, dtype=np.float64), np.sin(np.arange(50, dtype=np.float64)))

and

from numba import float64
@jitclass([
    ('a', float64[:]),
    ('b', float64[:]),    
    ('c', float64[:]),    
])
class Testaroo(object):
    def __init__(self, a, b):
        self.a = a
        self.b = b
        self.c = np.zeros(len(self.a), dtype=np.float64)
        
    def set_c(self):
        self.c = (self.a * self.b).cumsum()
        
testaroo = Testaroo(
    np.arange(50, dtype=np.float64), np.sin(np.arange(50, dtype=np.float64)))
testaroo.set_c()

This latter example will work for me for now, but I'd like to know if there's a way to get the cumutrapz function working inside a jitclass.

I'm using numba version '0.53.1'.


Solution

  • Carefully reading through the long error message, you can find:

    No implementation of function Function(<function trapz at 0x7f7e9b21e5e0>) 
    found for signature:
    >>> trapz(y=array(float64, 1d, A), x=array(float64, 1d, A))
    ...
    reshape() supports contiguous array only
    

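    Arrays with layout A ("any") are not necessarily contiguous. The class members are declared as float64[:], which Numba types as layout A, so slices such as y[i-2:i] taken from them are layout A as well. In the standalone call the inputs come straight from np.arange and are typed as layout C, which is why that version works.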

    You can make the function accept contiguous (layout C) arrays only by giving it an explicit signature:

    import numba as nb

    @njit([nb.float64[::1](nb.float64[::1], nb.float64[::1])])
    def cumutrapz(x, y):
        ...
    

    Then a new error appears:

    Invalid use of type(CPUDispatcher(<function MyTestCase.test_cumutrapz.<locals>.cumutrapz at 0x7f8e15a841f0>))
    with parameters (array(float64, 1d, A), array(float64, 1d, A))
    Known signatures:
        * (array(float64, 1d, C), array(float64, 1d, C)) -> array(float64, 1d, C)
    ...
        self.c = cumutrapz(self.a, self.b)
        ^
    

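    So the arrays stored in the class are not typed as contiguous: float64[:] declares a 1-D array of any layout (A), and the C-only signature does not accept it.

    To make them contiguous, declare the members with float64[::1] in the class specification: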

    @jitclass([
        ('a', nb.float64[::1]),
        ('b', nb.float64[::1]),
        ('c', nb.float64[::1]),
        ])
    

    Now it works (tested with Numba 0.54.0).