
Having a float64 or float32 attribute in a numba jitclass


How can a numba jitclass have an attribute that can be either a float64 or a float32? With functions, the following code works:

import numba
import numpy as np
from numba import njit
from numba.experimental import jitclass


@njit()
def f(a):
    print(a.dtype)
    return a[0]


a = np.zeros(3)
f(a)
f(a.astype(np.float32))

whereas using both float32 and float64 with a class attribute fails:

@jitclass([('arr', numba.types.float64[:])])
class MyClass():
    def __init__(self):
        pass

    def f(self, a):
        self.arr = a


myclass = MyClass()
myclass.f(np.zeros(3))
# following line fails:
myclass.f(np.zeros(3, dtype=np.float32))

Is there a workaround?


Solution

  • When you call MyClass(), Numba needs to instantiate the class, and because Numba only works with well-defined, strongly typed values (this is what makes it fast and so useful), the fields of the class must be typed before any object is instantiated. Thus, you cannot decide the type of the MyClass fields at the moment the method f is called, because that call is made by the CPython interpreter, which is dynamic. Note also that a class usually has more than one method (otherwise it would not be very useful), which is why partial compilation is not really possible either.

    One simple solution is to create two jitclass types from the same Python class, one per dtype:

    import numba
    import numpy as np
    from numba.experimental import jitclass

    class MyClass():
        def __init__(self):
            pass

        def f(self, a):
            self.arr = a

    # Build one jitclass specialization per dtype from the same Python class
    MyClass_float32 = jitclass([('arr', numba.types.float32[:])])(MyClass)
    MyClass_float64 = jitclass([('arr', numba.types.float64[:])])(MyClass)

    myclass = MyClass_float32()  # compiles the float32 class lazily and instantiates an object
    # The type of `self.arr` is already fixed to `float32[:]` at this point.
    myclass.f(np.zeros(3, dtype=np.float32))

    myclass = MyClass_float64()
    myclass.f(np.zeros(3, dtype=np.float64))