Tags: python, class, jit, numba, function-parameter

How to use numba.jit with methods


I'm using numba.jit in Python.

I can convert normal functions with jit and run them:

from numba import jit

def sum(a, b):
    return a+b

func = jit(sum)
print(func(1, 2))

How can I do the same for methods? Something like this (which doesn't work, and I know why):

from numba import jit

class some_class:
    def __init__(self, something = 0):
        self.number = something
    def get_num(self):
        return self.number

my_object = some_class(5)
func = jit(my_object.get_num)
print(my_object.func())

P.S. I've also tried decorators; that works, but I can't use them on imported classes (ones I don't define myself), so I'm looking for another way.


Solution

  • You cannot jit bound methods, but you can jit unbound methods (although only in object mode):

    from numba import jit
    
    class some_class:
        def __init__(self, something = 0):
            self.number = something
        def get_num(self):
            return self.number
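        # Object mode has to be requested explicitly in recent Numba versions,
        # which no longer fall back to it automatically.
        func = jit(get_num, forceobj=True)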
    
    my_object = some_class(5)
    print(my_object.func())
    # 5
    

    Note that this doesn't use nopython mode, so you shouldn't expect any meaningful speed-up. You could make the class itself a jitclass (which means all of its methods are compiled in nopython mode automatically), but that requires you to declare the attribute types:

    import numba as nb
    from numba.experimental import jitclass  # was numba.jitclass in older Numba versions
    
    spec = [
        ('number', nb.int64),
    ]
    
    @jitclass(spec)
    class some_class:
        def __init__(self, something):
            self.number = something
        def get_num(self):
            return self.number
    
    my_object = some_class(5)
    print(my_object.get_num())
    

    But for more complicated classes it becomes very hard (or impossible) to use jitclass. In my experience, the best way is simply to call jitted functions from within the method:

    from numba import njit  # like jit but enforces nopython-mode!
    
    @njit
    def my_func(val):
        return val  # this example is a bit stupid, I hope your real code does more!
    
    class some_class:
        def __init__(self, something = 0):
            self.number = something
        def get_num(self):
            return my_func(self.number)
    
    my_object = some_class(5)
    print(my_object.get_num())
    

    Which approach to use depends on how complex your class and/or your method is. In your case I wouldn't use Numba at all, because there isn't enough computationally expensive work to compensate for the Numba and JIT overhead. If it were a bit more complicated I would use jitclass, and if it were much more complicated I would call a jitted function from within the method. Personally, I would never use jit on a method, because that implicitly requires object mode, so the jitted function is likely to be slower than the unjitted one.

    By the way: in Python you generally use a property instead of get_* or set_* methods.
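A minimal sketch of that idiom in plain Python (no Numba involved): the value is stored in a hypothetical `_number` attribute and exposed through a `number` property, so callers write `my_object.number` instead of `my_object.get_num()`. The `int()` conversion in the setter is only an illustrative assumption of where validation could go.

```python
class some_class:
    def __init__(self, something=0):
        self._number = something

    @property
    def number(self):
        # Read access looks like a plain attribute: my_object.number
        return self._number

    @number.setter
    def number(self, value):
        # Hypothetical validation/conversion step.
        self._number = int(value)

my_object = some_class(5)
print(my_object.number)  # 5
my_object.number = 7
print(my_object.number)  # 7
```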