
Python Numba/jit conditional and recursive (stack) use


All,

I'm using numba JIT to speed up my Python code, but the code should be functional even if numba & LLVM are not installed.

My first idea was to do this as follows:

import timeit

use_numba = True
try:
    from numba import jit, int32
except ImportError:
    use_numba = False

def run_it(parameters):
    # do something
    pass

# define wrapper call function with optimizer
@jit
def run_it_with_numba(parameters):
    return run_it(parameters)

# [...]
# main program 
t_start = timeit.default_timer()

# this is the code I don't like 
if use_numba:
    res = run_it_with_numba(parameters)
else:
    res = run_it(parameters)

t_stop = timeit.default_timer()
print("Numba:", use_numba, "Time:", t_stop - t_start)

This does not work as I had expected, because the compilation seems to apply only to the run_it_with_numba() function (which basically does nothing), not to the subroutines called from it.

The results only get better when I apply @jit on the function that contains the workload.

Is there a way to avoid the wrapper function and the if-clause in the main program?

Is there a way to tell Numba to also optimize the subroutines called from my entry function? run_it() itself contains some function calls, and I expected @jit to deal with those.

cu, Ale


Solution

  • I think you want to do this differently. Instead of wrapping the function, just alias it conditionally. For example, using a dummy function that does enough work to give meaningful timings:

    import numpy as np
    import timeit

    # Assume numba is available; fall back to plain Python if the import fails.
    use_numba = True
    try:
        import numba as nb
    except ImportError:
        use_numba = False

    def _run_it(a, N):
        s = 0.0
        for k in range(N):
            s += k / np.sin(a)

        return s

    # alias run_it to the jitted version if numba is available
    if use_numba:
        print('Using numba')
        run_it = nb.jit()(_run_it)
    else:
        print('Falling back to python')
        run_it = _run_it

    if __name__ == '__main__':
        print(timeit.repeat('run_it(50.0, 100000)', setup='from __main__ import run_it', repeat=3, number=100))
    

    Running this with the use_numba flag as True:

    $ python nbtest.py
    Using numba
    [0.18746304512023926, 0.15185213088989258, 0.1636970043182373]
    

    and with it set to False:

    $ python nbtest.py
    Falling back to python
    [9.707707166671753, 9.779848098754883, 9.770231008529663]
    

    Or, in an IPython notebook, using the handy %timeit magic:

    run_it_numba = nb.jit()(_run_it)
    
    %timeit _run_it(50.0, 10000)
    100 loops, best of 3: 9.51 ms per loop
    
    %timeit run_it_numba(50.0, 10000)  
    10000 loops, best of 3: 144 µs per loop
    

    Note that when timing numba functions, the first call includes the time numba needs to compile the function for the given argument types. All subsequent calls with the same argument types reuse the compiled code and are much faster.
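On the second question (subroutines): numba compiles each decorated function on its own, and in nopython mode a jitted function can only call functions that numba can also compile, so helpers called from the entry point need their own decorator. The sketch below combines that with a no-op fallback decorator, which removes both the wrapper and the if-clause. The shim is an assumption of this sketch, not part of numba: it only handles the bare `@jit` and `@jit(...)` forms. `term` is a hypothetical helper name, and `math.sin` stands in for `np.sin` so the fallback path needs only the standard library.

```python
import math

# Use numba when it is installed; otherwise fall back to a no-op decorator
# so the same @jit-decorated code still runs as plain Python.
try:
    from numba import jit
    USE_NUMBA = True
except ImportError:
    USE_NUMBA = False

    def jit(*args, **kwargs):
        # Assumption: the shim only has to support two common forms,
        # bare @jit and @jit(...) with keyword options.
        if len(args) == 1 and callable(args[0]) and not kwargs:
            return args[0]  # bare @jit: return the function unchanged

        def decorator(func):
            return func  # @jit(...): ignore the options

        return decorator


# Numba compiles each decorated function separately, so the helper that
# run_it calls is decorated as well ("term" is a hypothetical helper).
@jit(nopython=True)
def term(k, a):
    return k / math.sin(a)


@jit(nopython=True)
def run_it(a, N):
    s = 0.0
    for k in range(N):
        s += term(k, a)
    return s


if __name__ == '__main__':
    print('Using numba' if USE_NUMBA else 'Falling back to python')
    print(run_it(50.0, 100000))
```

With this layout the main program just calls `run_it(...)`; whether it is compiled is decided once, at import time.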
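To keep compilation time out of the measurements mentioned in the note above, one option is to call the function once with the same argument types before timing it. This is a minimal sketch, assuming numba's `njit` decorator when numba is installed and a pass-through stand-in (bare `@njit` only) when it is not; the warm-up arguments are illustrative.

```python
import math
import timeit

try:
    from numba import njit
    USE_NUMBA = True
except ImportError:
    USE_NUMBA = False

    def njit(func):
        # Pass-through stand-in when numba is absent (bare @njit usage only).
        return func


@njit
def run_it(a, N):
    s = 0.0
    for k in range(N):
        s += k / math.sin(a)
    return s


# Warm-up: the first call compiles run_it for (float, int) arguments,
# so the timed calls below (same argument types) only measure execution.
run_it(50.0, 10)

times = timeit.repeat(lambda: run_it(50.0, 100000), repeat=3, number=5)
print('numba' if USE_NUMBA else 'python', min(times))
```

Calling the function later with different argument types (for example an integer `a`) would trigger another compilation, so the warm-up should use the same types as the timed calls.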