Determining input argument type for jitclass method


I'm working on a jitclass in which one of the methods can accept an input argument of type int, float, or numpy.ndarray. I need to determine whether the argument is an array or one of the other two types. I've tried using isinstance, as shown in the interp method below:

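import numpy as np
from numba import jitclass, float64  # in newer Numba versions: from numba.experimental import jitclass

spec = [('x', float64[:]),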
        ('y', float64[:])]


@jitclass(spec)
class Lookup:
    def __init__(self, x, y):
        self.x = x
        self.y = y

    def interp(self, x0):
        if isinstance(x0, (float, int)):
            result = self._interpolate(x0)
        elif isinstance(x0, np.ndarray):
            result = np.zeros(x0.size)
            for i in range(x0.size):
                result[i] = self._interpolate(x0[i])
        else:
            raise TypeError("`interp` method can only accept types of float, int, or ndarray.")
        return result

    def _interpolate(self, x0):
        x = self.x
        y = self.y
        if x0 < x[0]:
            return y[0]
        elif x0 > x[-1]:
            return y[-1]
        else:
            for i in range(len(x) - 1):
                if x[i] <= x0 <= x[i + 1]:
                    x1, x2 = x[i], x[i + 1]
                    y1, y2 = y[i], y[i + 1]

                    return y1 + (y2 - y1) / (x2 - x1) * (x0 - x1)

But I get the following error:

numba.errors.TypingError: Failed at nopython (nopython frontend)
Failed at nopython (nopython frontend)
Untyped global name 'isinstance': cannot determine Numba type of <class 'builtin_function_or_method'>
File "Lookups.py", line 17
[1] During: resolving callee type: BoundFunction((<class 'numba.types.misc.ClassInstanceType'>, 'interp') for instance.jitclass.Lookup#2167664ca28<x:array(float64, 1d, A),y:array(float64, 1d, A)>)
[2] During: typing of call at <string> (3)

Is there a way to determine whether an input argument is of a certain type when using jitclasses or in nopython mode?

Edit

I should have mentioned this before, but using the type built-in does not seem to work either. For example, if I replace the interp method with:

    def interp(self, x0):
        if type(x0) == float or type(x0) == int:
            result = self._interpolate(x0)
        elif type(x0) == np.ndarray:
            result = np.zeros(x0.size)
            for i in range(x0.size):
                result[i] = self._interpolate(x0[i])
        else:
            raise TypeError("`interp` method can only accept types of float, int, or ndarray.")
        return result

I get the following error:

numba.errors.TypingError: Failed at nopython (nopython frontend)
Failed at nopython (nopython frontend)
Invalid usage of == with parameters (class(int64), Function(<class 'float'>))

I think this refers to comparing Numba's int64 with Python's float when I call, for example, lookup_object.interp(370).


Solution

  • You're out of luck if you need to determine and compare the type inside a Numba jitclass or nopython jit function: isinstance isn't supported at all in Numba 0.35, and type supports only a few numeric types and namedtuples. Even then, type just returns the type and isn't suitable for comparisons, because == isn't implemented for classes inside Numba functions.

    As of Numba 0.35, the only supported built-ins are (source: Numba documentation):

    The following built-in functions are supported:

    abs()
    bool
    complex
    divmod()
    enumerate()
    float
    int: only the one-argument form
    iter(): only the one-argument form
    len()
    min()
    max()
    next(): only the one-argument form
    print(): only numbers and strings; no file or sep argument
    range: semantics are similar to those of Python 3 even in Python 2: a range object is returned instead of an array of values.
    round()
    sorted(): the key argument is not supported
    type(): only the one-argument form, and only on some types (e.g. numbers and named tuples)
    zip()
    

    My suggestion: use a normal Python class, determine the type there, and then forward the call to the matching numba.njit-compiled function:

    import numba as nb
    import numpy as np
    
    @nb.njit
    def _interpolate_one(x, y, x0):
        if x0 < x[0]:
            return y[0]
        elif x0 > x[-1]:
            return y[-1]
        else:
            for i in range(len(x) - 1):
                if x[i] <= x0 <= x[i + 1]:
                    x1, x2 = x[i], x[i + 1]
                    y1, y2 = y[i], y[i + 1]
    
                    return y1 + (y2 - y1) / (x2 - x1) * (x0 - x1)
            # reached only if no interval matched (e.g. x0 is NaN); keeps the return type float
            return np.nan
    
    @nb.njit
    def _interpolate_many(x, y, x0):
        result = np.zeros(x0.size, dtype=np.float64)  # np.float_ was removed in NumPy 2.0
        for i in range(x0.size):
            result[i] = _interpolate_one(x, y, x0[i])
        return result
    
    class Lookup:
        def __init__(self, x, y):
            self.x = x
            self.y = y
    
        def interp(self, x0):
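            # NumPy integer scalars (e.g. np.int64) are not subclasses of int, so list them explicitly
            if isinstance(x0, (int, float, np.integer, np.floating)):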
                result = _interpolate_one(self.x, self.y, x0)
            elif isinstance(x0, np.ndarray):
                result = _interpolate_many(self.x, self.y, x0)
            else:
                raise TypeError("`interp` method can only accept types of float, int, or ndarray.")
            return result