Using Numba jitclass with input from dicts and tuples


I am working on optimizing some code that is mostly contained within a single Python class. It does very little manipulation of Python objects, so I thought Numba would be a good match. However, I have a large number of parameters that I need when creating the object, and I don't think I fully understand Numba's relatively recent dict support as described in its documentation. The parameters are all single floats or ints; they are passed into the object, stored, and then used throughout the running of the code, like so:

import numpy as np
from numba import jitclass, float64

spec = [
    ('p', dict),
    ('shape', tuple),               # the shape of the array
    ('array', float64[:,:]),          # an array field
]

params_default = {
    'par_1': 1,
    'par_2': 0.5,
}

@jitclass(spec)
class myObj:
    def __init__(self,params = params_default,shape = (100,100)):
        self.p = params
        self.shape = shape
        self.array = self.p['par_2']*np.ones(shape)

    def inc_arr(self):
        self.array += self.p['par_1']*np.ones(self.shape)

There's quite a bit I don't think I understand about what Numba needs for this. If I want to optimize this with Numba using nopython mode, do I need to pass a spec into the jitclass decorator? How do I define the spec for the dictionary? Do I need to declare the shape tuple as well? I have looked at the documentation I found on the jitclass decorator as well as the dict numba documentation and I am not sure what to do. When I run the above code I get the following error:

TypeError: spec values should be Numba type instances, got <class 'dict'>

Do I need to include the dict elements in the spec somehow? It's not clear from the documentation what the correct syntax for that would be.

Alternately, is there a way to get Numba to infer the input types?


Solution

  • spec needs to be composed of Numba-specific types, not of Python types! So tuple and dict in spec must be replaced by typed Numba types (and afaik only homogeneous dicts are allowed).

    So either you create your params_default dict inside a jitted function, where Numba infers its type, or you explicitly type a Numba dict.

    For this case, I'll go with the latter approach:

    import numpy as np
    import numba as nb
    from numba import float64
    from numba.experimental import jitclass  # jitclass lives here since Numba 0.49
    
    # Explicitly define the types of the key and value:
    params_default = nb.typed.Dict.empty(
        key_type=nb.typeof('par_1'),
        value_type=nb.typeof(0.5)
    )
    
    # assign your default values
    params_default['par_1'] = 1.  # Same type required, thus setting to float
    params_default['par_2'] = .5
    
    spec = [
        ('p', nb.typeof(params_default)),
        ('shape', nb.typeof((100, 100))),               # the shape of the array
        ('array', float64[:, :]),          # an array field
    ]
    
    @jitclass(spec)
    class myObj:
        def __init__(self, params=params_default, shape=(100, 100)):
            self.p = params
            self.shape = shape
            self.array = self.p['par_2'] * np.ones(shape)
    
        def inc_arr(self):
            self.array += self.p['par_1'] * np.ones(self.shape)
    

    As already pointed out: The dict is, afaik, homogeneously typed. Thus all keys must share one type and all values must share one type. So storing int and float values in the same dict won't work.