I'm trying to make a jitclass that has some attributes which are themselves jitclass objects.
For example, if I have two classes with the @jitclass decorator, I would like to instantiate them in a third class (Combined).
import numpy as np
from numba import jitclass
from numba import boolean, int32, float64, uint8
spec = [
    ('type', int32),
    ('val', float64[:]),
    ('result', float64)]
@jitclass(spec)
class First:
    def __init__(self):
        self.type = 1
        self.val = np.ones(100)
        self.result = 0.
    def sum(self):
        self.result = np.sum(self.val)
@jitclass(spec)
class Second:
    def __init__(self):
        self.type = 2
        self.val = np.ones(100)
        self.result = 0.
    def sum(self):
        self.result = np.sum(self.val)
@jitclass(spec)
class Combined:
    def __init__(self):
        self.List = []
        for i in range(10):
            self.List.append(First())
            self.List.append(Second())
    def sum(self):
        for i, c in enumerate(self.List):
            c.sum()
    def getresult(self):
        result = []
        for i, c in enumerate(self.List):
            result.append(c.result)
        return result
C = Combined()
C.sum()
result = C.getresult()
print(result)
In that example I get an error because numba cannot determine the type of self.List, which is a combination of the two jitclasses.
How can I make the class Combined jitclass-compatible?
I tried something I found elsewhere:
import numpy as np
from numba import jitclass, deferred_type
from numba import boolean, int32, float64, uint8
from numba.typed import List
spec = [
    ('type', int32),
    ('val', float64[:]),
    ('result', float64)]
@jitclass(spec)
class First:
    def __init__(self):
        self.type = 1
        self.val = np.ones(100)
        self.result = 0.
    def sum(self):
        self.result = np.sum(self.val)
spec1 = [('ListA', List(First.class_type.instance_type, reflected=True))]
@jitclass(spec1)
class Combined:
    def __init__(self):
        self.ListA = [First(), First()]
    def sum(self):
        for i, c in enumerate(self.ListA):
            c.sum()
    def getresult(self):
        result = []
        for i, c in enumerate(self.ListA):
            result.append(c.result)
        return result
C = Combined()
C.sum()
result = C.getresult()
print(result)
But I get this error:
List(First.class_type.instance_type)
TypeError: __init__() takes 1 positional argument but 2 were given
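TL;DR:
- You can use jitclasses in a jitclass, even if you have a list of those. You just need to correct the namespace: numba.typed -> numba.types.
- Numba (at least version 0.46) doesn't support heterogeneous lists in jitclasses or no-python numba.jit functions. So you cannot append instances of both First and Second to the same list.
The numba.typed.List exception
Your update was almost correct. You need to use numba.types.List, not numba.typed.List. The difference is subtle: the numba.types namespace contains types for signatures (such as the entries of a jitclass spec), while the numba.typed namespace contains classes that can be instantiated and used in code. In numba 0.46 the numba.typed.List constructor doesn't accept a positional argument, which is why List(First.class_type.instance_type) fails with the TypeError above.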
So it will work if you use:
spec1 = [('ListA', nb.types.List(First.class_type.instance_type, reflected=True))]
With that change this code:
import numpy as np
import numba as nb
# numba 0.46 API: newer numba releases provide jitclass via "from numba.experimental import jitclass"
spec = [
    ('type', nb.int32),
    ('val', nb.float64[:]),
    ('result', nb.float64)
]
@nb.jitclass(spec)
class First:
    def __init__(self):
        self.type = 1
        self.val = np.ones(100)
        self.result = 0.
    def sum(self):
        self.result = np.sum(self.val)
# numba.types.List (a type), not numba.typed.List (a container)
spec1 = [('ListA', nb.types.List(First.class_type.instance_type, reflected=True))]
@nb.jitclass(spec1)
class Combined:
    def __init__(self):
        self.ListA = [First(), First()]
    def sum(self):
        for i, c in enumerate(self.ListA):
            c.sum()
    def getresult(self):
        result = []
        for i, c in enumerate(self.ListA):
            result.append(c.result)
        return result
C = Combined()
C.sum()
result = C.getresult()
print(result)
produces the output [100.0, 100.0].
Do you need a jitclass here?
Something to keep in mind, however, is that normal Python classes will probably be faster than the jitclass approach (or at least as fast):
import numpy as np
# Plain Python + NumPy version: no numba needed
class First:
    def __init__(self):
        self.type = 1
        self.val = np.ones(100)
        self.result = 0.
    def sum(self):
        self.result = np.sum(self.val)
class Combined:
    def __init__(self):
        self.ListA = [First(), First()]
    def sum(self):
        for i, c in enumerate(self.ListA):
            c.sum()
    def getresult(self):
        result = []
        for i, c in enumerate(self.ListA):
            result.append(c.result)
        return result
C = Combined()
C.sum()
print(C.getresult())
That's fine if this is just out of curiosity. But for production code I would start with pure Python and NumPy, and only apply numba if that turns out to be too slow, then only to the parts that are the bottleneck, and only if numba is good at optimizing them (numba is a specialized tool at the moment, not a general-purpose one).
With numba in no-python (no-object) mode you need homogeneous lists. As far as I know, numba 0.46 doesn't support lists containing different kinds of objects in jitclasses or nopython-jit methods. That means you cannot have one list containing both First and Second instances.
So this cannot work:
self.List.append(First())
self.List.append(Second())
From the numba docs:
Creating and returning lists from JIT-compiled functions is supported, as well as all methods and operations. Lists must be strictly homogeneous: Numba will reject any list containing objects of different types, even if the types are compatible [...]