
Python class and Numba jitclass for code with Numba functions


At some point in my code, I call a Numba function, and all subsequent computations are done with Numba jitted functions until the post-processing steps.

Over the past few days, I have been looking for an efficient way to send all the variables (mostly booleans, integers, floats, and float arrays) to the Numba part of the code while keeping the code readable and clear. In my case, that means limiting the number of arguments and, if possible, grouping variables according to the system they refer to.

I identified four ways to do this:

  1. brute force: sending all the variables one by one as arguments of the first Numba function called. I find this solution unacceptable, as it makes the code barely readable (a very long argument list) and is inconsistent with my wish to group related variables,
  2. Numba typed dictionaries (see this post for instance): I did not find this solution acceptable, as I understand a given dictionary can only contain values of a single type (a dictionary of float64 only, for instance), while a given system may have related variables of different types; a minimal sketch of this constraint follows the list. In addition, I observed a significant loss of performance (roughly 10% longer computation times) with this option,
  3. Numba namedtuples: fairly easy to implement and use, but my understanding is that these can only be used efficiently if defined within a Numba function, and thus cannot be sent from a non-jitted function to a jitted function without losing the benefit of the cache=True option. This is a dealbreaker for me, as compilation times may exceed the execution time of the code itself.
  4. Numba @jitclass: I was initially reluctant to use classes for my code, but it turns out to be very practical... however, as with the namedtuples, if an object from a @jitclass is instantiated in a non-jitted function, I observed that it becomes impossible to benefit from the cache=True option (see this post).
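
For illustration, here is a minimal sketch of the constraint mentioned for option 2 (the key names are just placeholders): every value of a numba.typed.Dict must have the single declared type, so the booleans, integers, and arrays describing a given system cannot share one dictionary.

import numba as nb
from numba import njit
from numba.typed import Dict

# all values must be float64 here; booleans, integers or arrays cannot be mixed in
params = Dict.empty(key_type=nb.types.unicode_type, value_type=nb.types.float64)
params['a'] = 11.0
params['b'] = 22.0

@njit
def function_numba_dict(params):
    return params['a'] + params['b']

print(function_numba_dict(params))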

In the end, none of these four alternatives lets me do what I want. I am probably missing something...

Here is what I did in the end: I combined regular Python classes with Numba @jitclass in order to keep the possibility of benefiting from the cache=True option.

Here is my MWE:

import numba as nb
from numba import jit
from numba.experimental import jitclass

spec_cls = [
    ('a', nb.types.float64),
    ('b', nb.types.float64),
]

# python class
class ClsWear_py(object):
    def __init__(self, a, b):
        self.a = a
        self.b = b

# mirror Numba class
@jitclass(spec_cls)
class ClsWear(object):
    def __init__(self, a, b):
        self.a = a
        self.b = b
   
def function_python(obj):
    print('from the python class :', obj.a)
    # call to the Numba function => this is where I must explicitly list all the attributes of the Python class object
    oa, ob = function_numba(obj.a, obj.b)
    return obj, oa, ob

@jit(nopython=True, cache=True)
def function_numba(oa, ob):
    # at the beginning of the Numba function, the arguments are used to build the @jitclass object
    obj_nb = ClsWear(oa, ob)
    print('from the numba class :', obj_nb.a)
    return obj_nb.a, obj_nb.b

# main code:
obj_py = ClsWear_py(11, 22)

obj_rt, a, b = function_python(obj_py)

The output of this code is:

$ python mwe.py 
from the python class : 11
from the numba class : 11.0

On the plus side:

  • I have a clean data structure in python and Numba (use of classes)
  • I keep a fast running code and cache=True is working

But on the downside:

  • I must define Python classes and their Numba mirrors
  • there remains one barely readable part of the code: the first call to a jitted function, where all the contents of my objects must be listed explicitly

Am I missing something? Is there a more obvious way to do this?


Solution

  • Based on your research (which I agree with), there doesn't seem to be an intuitive way to do this. So I'm going to suggest the least horrible workaround.

    The idea is to pass the data as a tuple to the jitted function and then convert it to the desired class inside the jitted function.

    from typing import NamedTuple
    from numba import njit
    
    
    class Config(NamedTuple):
        a: int
        b: float
        c: str
    
    
    @njit(cache=True)
    def f(config_values):
        config = Config(*config_values)  # Convert to a named tuple.
        return config.a
    
    
    def main():
        config = Config(1, 2.0, "3")
        f(tuple(config))  # Pass as a tuple.
        print("Cache:", f.stats)
    
    
    main()
    

    Result (2nd run):

    Cache: _CompileStats(cache_path=..., cache_hits=Counter({(Tuple((int64, float64, unicode_type)),): 1}), cache_misses=Counter())
    

    As you can see, it is correctly cached as a tuple.

    One of the issues with this workaround is that named tuples are read-only. So you cannot modify fields. In that case, you could do the same thing with a jitclass. Also note that you can create a mirror class for numba from its Python class counterpart, because @jitclass is a regular Python decorator.

    from numba import njit
    from numba.experimental import jitclass
    
    
    class Config:
        # These type hints are interpreted as the default specs for jitclass.
        # If you only use primitive types, this should be sufficient.
        a: int
        b: float
        c: str
    
        def __init__(self, a, b, c):
            self.a = a
            self.b = b
            self.c = c
    
        def as_tuple(self):
            return self.a, self.b, self.c
    
    
    JitConfig = jitclass(Config)
    # JitConfig = jitclass(spec)(Config) if you need to specify the spec explicitly.
    
    
    @njit(cache=True)
    def f(config_values):
        config = JitConfig(*config_values)  # Convert to a jitclass.
        return config.a
    
    
    def main():
        # Since we use a tuple as an argument, it is not mandatory to use a jitclass here.
        config = Config(1, 2.0, "3")
    
        f(config.as_tuple())  # Pass as a tuple.
        print("Cache:", f.stats)
    
    
    main()
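
    Since the jitclass fields are writable, the object can also be mutated inside the jitted function, which is not possible with the named tuple above. Here is a minimal sketch reusing the JitConfig defined above (the function name g is purely illustrative):

    @njit(cache=True)
    def g(config_values):
        config = JitConfig(*config_values)  # Convert to a jitclass.
        config.a = config.a + 10  # Fields can be reassigned, unlike with a NamedTuple.
        return config.a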
    

    You can also use overload to hide the jitclass entirely. Notice that in the code below, it is no longer necessary to explicitly use the jitclass.

    from numba import njit
    from numba.core.extending import overload
    from numba.experimental import jitclass
    
    
    class Config:
        a: int
        b: float
        c: str
    
        def __init__(self, a, b, c):
            self.a = a
            self.b = b
            self.c = c
    
        def as_tuple(self):
            return self.a, self.b, self.c
    
    
    _JitConfig = jitclass(Config)
    
    
    @overload(Config, strict=False)
    def overload_config_init(*args):
        def jit_config_init(*args):
            return _JitConfig(*args)
    
        return jit_config_init
    
    
    @njit(cache=True)
    def f(config_values):
    
        # Since it will be overloaded to a jitclass, you can use a Python class here.
        config = Config(*config_values)
    
        return config.a
    
    
    def main():
        config = Config(1, 2.0, "3")
    
        f(config.as_tuple())
        print("Cache:", f.stats)
    
    
    main()
    

    As for performance, the overhead should be negligible. Here is a benchmark:

    import timeit
    from typing import NamedTuple
    
    from numba import njit
    from numba.core.extending import overload
    from numba.experimental import jitclass
    
    
    class NamedTupleContainer(NamedTuple):
        a0: int
        a1: int
        a2: int
        a3: int
        a4: int
        a5: float
        a6: float
        a7: float
        a8: float
        a9: float
    
    
    class JitclassContainer:
        a0: int
        a1: int
        a2: int
        a3: int
        a4: int
        a5: float
        a6: float
        a7: float
        a8: float
        a9: float
    
        def __init__(self, a0, a1, a2, a3, a4, a5, a6, a7, a8, a9):
            self.a0 = a0
            self.a1 = a1
            self.a2 = a2
            self.a3 = a3
            self.a4 = a4
            self.a5 = a5
            self.a6 = a6
            self.a7 = a7
            self.a8 = a8
            self.a9 = a9
    
        def as_tuple(self):
            return (
                self.a0,
                self.a1,
                self.a2,
                self.a3,
                self.a4,
                self.a5,
                self.a6,
                self.a7,
                self.a8,
                self.a9,
            )
    
    
    _JitclassContainer = jitclass(JitclassContainer)
    
    
    @overload(JitclassContainer)
    def overload_container_init(*args):
        def jit_container_init(*args):
            return _JitclassContainer(*args)
    
        return jit_container_init
    
    
    @njit(cache=True)
    def f_multi_args(a0, a1, a2, a3, a4, a5, a6, a7, a8, a9):
        return a0
    
    
    @njit(cache=True)
    def f_namedtuple(args):
        c = NamedTupleContainer(*args)
        return c.a0
    
    
    @njit(cache=True)
    def f_jitclass(args):
        c = JitclassContainer(*args)
        return c.a0
    
    
    def main():
        def benchmark(f):
            n_runs = 10000
            return min(timeit.repeat(f, repeat=100, number=n_runs)) / n_runs
    
        values = 1, 2, 3, 4, 5, 6.0, 7.0, 8.0, 9.0, 10.0
        a0, a1, a2, a3, a4, a5, a6, a7, a8, a9 = values
        named_tuple_container = NamedTupleContainer(*values)
        jitclass_container = JitclassContainer(*values)
    
        t = benchmark(lambda: f_multi_args(a0, a1, a2, a3, a4, a5, a6, a7, a8, a9))
        print(f"f_multi_args: {t * 10 ** 9:.0f} ns")
    
        t = benchmark(lambda: f_namedtuple(tuple(named_tuple_container)))
        print(f"f_namedtuple: {t * 10 ** 9:.0f} ns")
    
        t = benchmark(lambda: f_jitclass(jitclass_container.as_tuple()))
        print(f"f_jitclass  : {t * 10 ** 9:.0f} ns")
    
    
    main()
    

    Result:

    f_multi_args: 525 ns
    f_namedtuple: 695 ns
    f_jitclass  : 652 ns
    

    On my PC, the difference was less than 200 nanoseconds per function call with 10 arguments/fields.