python, data-structures, numba

What are the guidelines for using numba for a tree structure?


Edit: Forgot to run numba more than once (oops!)

I've looked at the Numba versions of namedtuple and Dict as potential solutions, but they seem much slower (about 10,000x slower) than their Python counterparts.

import time
from numba import njit
from collections import namedtuple    

Alpha = namedtuple("Alpha", ["a", "b", "c"])    
Regions = namedtuple("Regions", ["a", "b"])    
State = namedtuple("State", ["H", "L"])    
Parameters = namedtuple("Parameters", ["alpha", "DB", "beta", "psi", "pi", "CC_opt"])    

def timer_func(func):
    def function_timer(*args, **kwargs):    
        start = time.time()    
        value = func(*args, **kwargs)    
        end = time.time()    
        runtime = end - start    
        msg = "{func} took {time} seconds to complete its execution."    
        print(msg.format(func=func.__name__, time=runtime))    
        return value    

    return function_timer    


@timer_func    
def build_params() -> Parameters:    
    alpha = Regions(    
        a=Alpha(0.5, 0.5, 0),    
        b=Alpha(0.5, 0.5, 0),    
    )    

    return Parameters(alpha=alpha, DB=State(0.0, 0.0), beta=0.8, psi=0.0, pi=0.5, CC_opt=1.0)    


@timer_func    
@njit    
def build_params_numba() -> Parameters:    
    alpha = Regions(    
        a=Alpha(0.5, 0.5, 0),    
        b=Alpha(0.5, 0.5, 0),    
    )    

    return Parameters(alpha=alpha, DB=State(0.0, 0.0), beta=0.8, psi=0.0, pi=0.5, CC_opt=1.0)    


if __name__ == "__main__":    
    build_params()    
    build_params_numba()

build_params took 3.814697265625e-06 seconds to complete its execution.

build_params_numba took 0.07473492622375488 seconds to complete its execution.

Edit:

build_params_numba took 3.5762786865234375e-06 seconds to complete its execution.


Solution

  • The biggest issue is that you are measuring the first call to build_params_numba, which includes compilation (the function is compiled just in time, on its first call, as you requested). This is like comparing the time-to-dinner of a conventional meal and a microwave meal while counting the time to buy and install the microwave oven as part of the latter. Measure the second call to build_params_numba, after compilation has completed, to see how the compiled function performs.

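  • The second issue is that Numba might not be of much help with this code. It is designed to speed up numerical algorithms and NumPy code. Because the @njit function ran without raising an error, it did compile in nopython mode: Numba supports namedtuple classes that are defined at module level, and @njit raises a typing error rather than silently falling back to object mode. The function only builds a handful of small tuples, though, so there is no numerical work to accelerate. That matches the edit in the question: once compilation is excluded, both versions take a few microseconds. A plain Python dict is a different case, since Numba cannot infer a type for one passed in from Python; as far as I know, numba.typed.Dict is the supported alternative there.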
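The measurement suggested above can be sketched as follows. This is a minimal sketch, not the asker's exact benchmark: it reuses the named tuples and build_params_numba from the question, times the first call on its own, and then averages many later calls. The helper time_calls is a hypothetical name introduced here, time.perf_counter is swapped in for time.time because it has better resolution for microsecond-scale timings, and the ImportError fallback is an assumption added only so the sketch still runs where Numba is not installed.

```python
import time
from collections import namedtuple

try:
    from numba import njit
except ImportError:
    # Assumption for this sketch only: without Numba installed, fall back to
    # a no-op decorator so the timing pattern itself can still be run.
    def njit(func):
        return func

# Same named tuples as in the question, defined at module level.
Alpha = namedtuple("Alpha", ["a", "b", "c"])
Regions = namedtuple("Regions", ["a", "b"])
State = namedtuple("State", ["H", "L"])
Parameters = namedtuple("Parameters", ["alpha", "DB", "beta", "psi", "pi", "CC_opt"])


@njit
def build_params_numba():
    alpha = Regions(
        a=Alpha(0.5, 0.5, 0),
        b=Alpha(0.5, 0.5, 0),
    )
    return Parameters(alpha=alpha, DB=State(0.0, 0.0), beta=0.8, psi=0.0, pi=0.5, CC_opt=1.0)


def time_calls(func, repeats=1000):
    """Hypothetical helper: average seconds per call over `repeats` calls."""
    start = time.perf_counter()
    for _ in range(repeats):
        func()
    return (time.perf_counter() - start) / repeats


# First call: triggers JIT compilation, so it is timed separately.
start = time.perf_counter()
params = build_params_numba()
first_call = time.perf_counter() - start

# Later calls reuse the already compiled code.
steady_state = time_calls(build_params_numba)

print(f"first call (includes compilation): {first_call:.6f} s")
print(f"later calls (average):             {steady_state:.9f} s")
```

Averaging over many calls, rather than timing a single second call, smooths out timer resolution and scheduling noise, which matters when each call takes only a few microseconds.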