python performance numba processing-efficiency

Why is Numba slower than pure Python in my code?


I'm fairly new to Python and was experimenting with Numba. I wrote some code that runs slower under Numba than in pure Python: for small inputs, pure Python is about 4x faster, and for large inputs the two run at pretty much the same speed. What is making my code slow in Numba?

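from numba import njit

@njit
def forr(q):
    p = 0

    # Count the decimal digits of q.
    k = q
    n = 0
    while k != 0:
        n += 1
        k = k // 10

    # Lower bound of the search: max(q - 9*n, 0).
    h = (abs(q - n * 9) + q - n * 9) // 2
    for j in range(q, h, -1):
        # Sum the decimal digits of j.
        s = 0
        k = j
        while k != 0:
            s += k % 10
            k = k // 10

        if s + j == q:
            p = 1
            print('Yes')
            break
    if p == 0:
        print('No')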


Solution

  • I think your Numba code runs slower for the following reasons:

    1. You are probably timing the first call of the function. On the very first call Numba JIT-compiles the code, which can take seconds. To measure correctly, make one separate call to the Numba function first so that it is compiled before timing starts.
    2. Your inputs may be too small, so the function takes very little time, and every call to a Numba function has some startup overhead. If possible, put reasonably long-running algorithms inside Numba functions, ones that take at least tens of milliseconds to run.
    3. You may be measuring too few runs. Time hundreds of runs of the function in a loop to get more accurate results.
    4. You didn't pass the cache = True option to the @njit decorator. With this option, each run of the script loads the previously compiled code instead of compiling it from scratch.
    5. A print call inside a function that takes very little time can account for a considerable share of that time, because console operations are quite slow. It is better to return the result from the function and print it outside the Numba function.
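Points 1 and 3 can be illustrated with a minimal sketch that times the first call separately from the average of many later calls. The names time_calls and digit_sum are hypothetical helpers written for this illustration, not part of Numba. The sketch falls back to the plain Python function if Numba is not installed, and it leaves out cache = True so it can also be run from an interactive session.

```python
import time


def time_calls(func, arg, runs=200):
    """Return (first_call_seconds, mean_seconds_of_later_calls) for func(arg)."""
    start = time.perf_counter()
    func(arg)  # first call: for an @njit function this includes JIT compilation
    first = time.perf_counter() - start

    start = time.perf_counter()
    for _ in range(runs):  # many runs, so the average is less noisy
        func(arg)
    mean_later = (time.perf_counter() - start) / runs
    return first, mean_later


def digit_sum(k):
    """Sum of the decimal digits of a non-negative integer."""
    s = 0
    while k != 0:
        s += k % 10
        k //= 10
    return s


try:
    from numba import njit
    target = njit(digit_sum)  # compiled lazily, on the first call
except ImportError:
    target = digit_sum  # assumption: fall back to pure Python without Numba

first, mean_later = time_calls(target, 123456789)
print(f"first call: {first * 1e6:.1f} us, "
      f"later calls: {mean_later * 1e6:.3f} us on average")
```

With Numba installed, the first-call time should be much larger than the later average, because it includes compilation. That first call is the measurement to exclude.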

    Taking all of the above into account, I wrote the following code to measure your Numba function. I only added the cache = True option and commented out the print() calls for the duration of the measurement (so the console isn't flooded with hundreds of lines).

    The code shows that the Numba version is about 29x faster on my laptop. To run it, install the required packages once with the command pip install numba timerit.

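    import timerit, numba
    timerit.Timerit._default_asciimode = True

    def forr(q):
        p = 0

        # Count the decimal digits of q.
        k = q
        n = 0
        while k != 0:
            n += 1
            k = k // 10

        # Lower bound of the search: max(q - 9*n, 0).
        h = (abs(q - n * 9) + q - n * 9) // 2
        for j in range(q, h, -1):
            # Sum the decimal digits of j.
            s = 0
            k = j
            while k != 0:
                s += k % 10
                k = k // 10

            if s + j == q:
                p = 1
                #print('Yes')
                break
        if p == 0:
            #print('No')
            pass

    # Compile the same function with Numba, caching the result on disk.
    nforr = numba.njit(cache = True)(forr)
    nforr(2) # Warm-up call, so JIT compilation is not timed

    tb = None  # mean time of the pure Python baseline
    for f in [forr, nforr]:
        tim = timerit.Timerit(num = 99, verbose = 1)
        for t in tim:
            f(1 << 60)
        if tb is None:
            tb = tim.mean()
        else:
            print(f'speedup {round(tb / tim.mean(), 1)}x')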
    

    Output:

    Timed best=1.029 ms, mean=1.040 +- 0.0 ms
    Timed best=35.300 us, mean=35.673 +- 0.3 us
    speedup 29.2x
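Point 5 suggests returning the result instead of printing it inside the function. Here is a minimal sketch of that variant. The name has_generator is hypothetical, and max(q - 9 * n, 0) is used as an equivalent way of writing the original lower bound. The body uses only integer arithmetic and loops, so I assume it can be wrapped with numba.njit unchanged, but the sketch itself runs as plain Python.

```python
def has_generator(q):
    """Return True if some j in (max(q - 9*n, 0), q] has j + digit_sum(j) == q,
    where n is the number of decimal digits of q."""
    # Count the decimal digits of q.
    n = 0
    k = q
    while k != 0:
        n += 1
        k //= 10

    # Same bound as (abs(q - n*9) + q - n*9) // 2 in the original code.
    h = max(q - 9 * n, 0)
    for j in range(q, h, -1):
        # Sum the decimal digits of j.
        s = 0
        k = j
        while k != 0:
            s += k % 10
            k //= 10
        if s + j == q:
            return True
    return False


# Printing happens outside the hot function, once per result.
print('Yes' if has_generator(21) else 'No')
```

Returning a value also makes the function easier to time, since the console is no longer part of the measurement.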