
How to make a Python class jitclass-compatible when it contains other jitclass objects?


I'm trying to make a class that can be a jitclass but has some attributes that are themselves jitclass objects.

For example, if I have two classes decorated with @jitclass, I would like to instantiate them in a third class (Combined).

import numpy as np
from numba import jitclass
from numba import boolean, int32, float64,uint8

spec = [
    ('type' ,int32),
    ('val' ,float64[:]),
    ('result',float64)]

@jitclass(spec)
class First:
    def __init__(self):
        self.type = 1
        self.val = np.ones(100)
        self.result = 0.
    def sum(self):
        self.result = np.sum(self.val)

@jitclass(spec)
class Second:
    def __init__(self):
        self.type = 2
        self.val = np.ones(100)
        self.result = 0.
    def sum(self):
        self.result = np.sum(self.val)



@jitclass(spec)
class Combined:
    def __init__(self):
        self.List = []
        for i in range(10):
            self.List.append(First())
            self.List.append(Second())

    def sum(self):
        for i, c in enumerate(self.List):
            c.sum()
    def getresult(self):
        result = []
        for i, c in enumerate(self.List):
            result.append(c.result)
        return result


C = Combined()
C.sum()
result = C.getresult()
print(result)

In that example I get an error because numba cannot determine the type of self.List, which contains instances of both jitclasses.

How can I make the class Combined be jitclass compatible?

Update

I tried something I found elsewhere:

import numpy as np
from numba import jitclass, deferred_type
from numba import boolean, int32, float64,uint8
from numba.typed import List

spec = [
    ('type' ,int32),
    ('val' ,float64[:]),
    ('result',float64)]

@jitclass(spec)
class First:
    def __init__(self):
        self.type = 1
        self.val = np.ones(100)
        self.result = 0.
    def sum(self):
        self.result = np.sum(self.val)
 

 
spec1 = [('ListA',  List(First.class_type.instance_type, reflected=True))]

@jitclass(spec1)
class Combined:
    def __init__(self):
        self.ListA = [First(),First()] 

    def sum(self):
        for i, c in enumerate(self.ListA):
            c.sum()
    def getresult(self):
        result = []
        for i, c in enumerate(self.ListA):
            result.append(c.result)
        return result


C = Combined()
C.sum()
result = C.getresult()
print(result)

But I get this error:

List(First.class_type.instance_type)
TypeError: __init__() takes 1 positional argument but 2 were given

Solution

  • TL;DR:

    • You can reference other jitclasses in a jitclass, even when you store them in a list. You just need to correct the namespace: numba.typed -> numba.types.
    • It's currently (as of numba 0.46) not possible to have heterogeneous lists in jitclasses or nopython numba.jit functions, so you cannot append instances of both First and Second to the same list.

    Solving the numba.typed.List exception

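    Your update was almost correct. You need to use numba.types.List, not numba.typed.List. The difference is subtle: numba.types contains type descriptions used in signatures and jitclass specs, while numba.typed contains container classes that can be instantiated and used in code. That is also why the call in your update failed: numba.typed.List(...) tries to construct a list, and its constructor does not accept the type you passed as a positional argument.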

    So it will work if you use:

    spec1 = [('ListA',  nb.types.List(First.class_type.instance_type, reflected=True))]
    

    With that change this code:

    import numpy as np
    import numba as nb
    
    spec = [
        ('type', nb.int32),
        ('val', nb.float64[:]),
        ('result', nb.float64)
    ]
    
    @nb.jitclass(spec)
    class First:
        def __init__(self):
            self.type = 1
            self.val = np.ones(100)
            self.result = 0.
        def sum(self):
            self.result = np.sum(self.val)
    
    spec1 = [('ListA',  nb.types.List(First.class_type.instance_type, reflected=True))]
    
    @nb.jitclass(spec1)
    class Combined:
        def __init__(self):
            self.ListA = [First(), First()] 
        def sum(self):
            for i, c in enumerate(self.ListA):
                c.sum()
        def getresult(self):
            result = []
            for i, c in enumerate(self.ListA):
                result.append(c.result)
            return result
    
    C = Combined()
    C.sum()
    result = C.getresult()
    print(result)
    

    produces the output: [100.0, 100.0].

    Intermezzo: Does it make sense to use jitclass here?

    However, keep in mind that normal Python classes will probably be as fast as, or faster than, the jitclass approach here:

    import numpy as np
    import numba as nb
    
    class First:
        def __init__(self):
            self.type = 1
            self.val = np.ones(100)
            self.result = 0.
        def sum(self):
            self.result = np.sum(self.val)
    
    class Combined:
        def __init__(self):
            self.ListA = [First(), First()] 
        def sum(self):
            for i, c in enumerate(self.ListA):
                c.sum()
        def getresult(self):
            result = []
            for i, c in enumerate(self.ListA):
                result.append(c.result)
            return result
    
    C = Combined()
    C.sum()
    C.getresult()
    

    That's no problem if this is just out of curiosity. But for production code I would start with pure Python and NumPy and only apply numba if that turns out to be too slow, and then only to the parts that are the bottleneck, and only if numba is good at optimizing them (numba is a specialized tool at the moment, not a general-purpose one).

    Heterogeneous (mixed-typed) lists with numba?

    In nopython (no-object) mode numba requires homogeneous lists. As far as I know, numba 0.46 doesn't support lists containing different kinds of objects in jitclasses or nopython-jitted functions. That means you cannot have one list containing both First and Second instances.

    So this cannot work:

    self.List.append(First())
    self.List.append(Second())
    

    From the numba docs:

    Creating and returning lists from JIT-compiled functions is supported, as well as all methods and operations. Lists must be strictly homogeneous: Numba will reject any list containing objects of different types, even if the types are compatible [...]