python numpy jit numba

Trouble with speeding up functions with numba JIT


I am new to Numba's jit. For a personal project, I need to speed up functions similar to the ones shown below, which I have simplified into standalone examples.

import numpy as np
from numba import jit, double, int32, void  # int32 is needed by the signatures below

def f(n):
    k=0.
    for i in range(n):
        for j in range(n):
            k+= i+j

def f_with_return(n):
    k=0.
    for i in range(n):
        for j in range(n):
            k+= i+j
    return k

def f_with_arange(n):
    k=0.
    for i in np.arange(n):
        for j in np.arange(n):
            k+= i+j

def f_with_arange_and_return(n):
    k=0.
    for i in np.arange(n):
        for j in np.arange(n):
            k+= i+j
    return k


#jit decorators
jit_f = jit(void(int32))(f)
jit_f_with_return = jit(int32(int32))(f_with_return)
jit_f_with_arange = jit(void(double))(f_with_arange)
jit_f_with_arange_and_return = jit(double(double))(f_with_arange_and_return)

And the benchmarks:

%timeit f(1000)
%timeit jit_f(1000)

10 loops, best of 3: 73.9 ms per loop / 1000000 loops, best of 3: 212 ns per loop

%timeit f_with_return(1000)
%timeit jit_f_with_return(1000)

10 loops, best of 3: 74.9 ms per loop / 1000000 loops, best of 3: 220 ns per loop

I don't understand these two:

%timeit f_with_arange(1000.0)
%timeit jit_f_with_arange(1000.0)

10 loops, best of 3: 175 ms per loop / 1 loops, best of 3: 167 ms per loop

%timeit f_with_arange_and_return(1000.0)
%timeit jit_f_with_arange_and_return(1000.0)

10 loops, best of 3: 174 ms per loop / 1 loops, best of 3: 172 ms per loop

I suspect I'm not giving jit the correct input and output types. Just because the for loop now runs over a numpy.arange instead of a plain range, jit no longer makes it faster. What is the issue here?


Solution

  • Put simply, Numba doesn't know how to convert a loop over np.arange into a low-level native loop, so it falls back to object mode, which is much slower than native code and usually about the same speed as pure Python.

    A nice trick is to pass the nopython=True keyword argument to jit: compilation then fails with an error instead of silently falling back to object mode, so you can see whether everything compiles natively:

    import numpy as np
    import numba as nb
    
    def f_with_return(n):
        k=0.
        for i in range(n):
            for j in range(n):
                k+= i+j
        return k
    
    jit_f_with_return = nb.jit()(f_with_return)
    jit_f_with_return_nopython = nb.jit(nopython=True)(f_with_return)
    
    %timeit f_with_return(1000)
    %timeit jit_f_with_return(1000)
    %timeit jit_f_with_return_nopython(1000)
    

    The last two run at the same speed on my machine and are much faster than the un-jitted code. The two examples you asked about will raise an error with nopython=True, since Numba cannot compile np.arange in nopython mode at this point (the documentation linked below is for version 0.17.0).

    See the following for more details:

    http://numba.pydata.org/numba-doc/0.17.0/user/troubleshoot.html#the-compiled-code-is-too-slow

    and for a list of supported numpy features with indications of what is and is not supported in nopython mode:

    http://numba.pydata.org/numba-doc/0.17.0/reference/numpysupported.html
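
One possible workaround, sketched here rather than taken from the original answer, is to keep the arithmetic but loop over a plain range with an integer bound, since the range-based version is the one that compiles to fast native code. The name f_range_and_return and the ImportError fallback (which only lets the snippet run where Numba is not installed) are assumptions made for illustration.

```python
import numpy as np

try:
    from numba import jit
except ImportError:
    # Assumption for illustration: without Numba, use a no-op decorator
    # so the example still runs (just without any speed-up).
    def jit(**kwargs):
        def decorator(func):
            return func
        return decorator


def f_with_arange_and_return(n):
    # Pattern from the question: iterate over NumPy arrays.
    k = 0.
    for i in np.arange(n):
        for j in np.arange(n):
            k += i + j
    return k


@jit(nopython=True)
def f_range_and_return(n):
    # Hypothetical rewrite: same arithmetic, plain range, integer n.
    k = 0.
    for i in range(n):
        for j in range(n):
            k += i + j
    return k


n = 200
expected = n * n * (n - 1)  # closed form of the double sum of i + j

# Both versions should agree with the closed form.
assert f_with_arange_and_return(float(n)) == expected
assert f_range_and_return(n) == expected
```

Timing both functions with %timeit, as in the question, should show whether the range-based version is faster on a given setup.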