python numpy numba jit

Numba is not enhancing the performance


I am testing Numba's performance on a function that takes a NumPy array, comparing it against the equivalent plain NumPy code:

import numpy as np
from numba import jit, vectorize, float64
import time
from numba.core.errors import NumbaWarning
import warnings

warnings.simplefilter('ignore', category=NumbaWarning)

@jit(nopython=True, boundscheck=False) # Set "nopython" mode for best performance, equivalent to @njit
def go_fast(a):     # Function is compiled to machine code when called the first time
    trace = 0.0
    for i in range(a.shape[0]):   # Numba likes loops
        trace += np.tanh(a[i, i]) # Numba likes NumPy functions
    return a + trace              # Numba likes NumPy broadcasting
   
class Main(object):
    def __init__(self) -> None:
        super().__init__()
        self.mat     = np.arange(100000000, dtype=np.float64).reshape(10000, 10000)

    def my_run(self):
        st = time.time()
        trace = 0.0
        for i in range(self.mat.shape[0]):   
            trace += np.tanh(self.mat[i, i]) 
        res = self.mat + trace
        print('Python Duration: ', time.time() - st)
        return res

    def jit_run(self):
        st = time.time()
        res = go_fast(self.mat)
        print('Jit Duration: ', time.time() - st)
        return res

obj = Main()
x1 = obj.my_run()
x2 = obj.jit_run()

The output is:

Python Duration:  0.2164750099182129
Jit Duration:  0.5367801189422607

How can I get a faster version of this example?


Solution

  • The Numba version is slower because the measured time includes compilation: Numba compiles the function the first time it is called (and again only if the argument types change), because it cannot know the argument types before the call. Fortunately, you can give Numba an explicit signature so that it compiles the function eagerly, when the decorator is executed. Here is the resulting code:

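    from numba import njit  # the question's script only imports jit

    @njit('float64[:,:](float64[:,:])')  # explicit signature: compiled at decoration time
    def go_fast(a):
        trace = 0.0
        for i in range(a.shape[0]):
            trace += np.tanh(a[i, i])
        return a + trace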
    

    Note that njit is a shortcut for jit(nopython=True) and that boundscheck already defaults to False (see the documentation).

    On my machine this results in the same execution time for both NumPy and Numba. Indeed, the execution time is not bounded by the computation of the tanh function. It is bounded by the expression a + trace (for both Numba and NumPy). The same execution time is expected since both implement it the same way: they create a new temporary array to hold the result of the addition. Creating this temporary array is expensive because of page faults and RAM traffic (a is fully read from RAM and the temporary array is fully written to RAM). If you want a faster computation, you need to perform the operation in place (this prevents page faults and expensive cache-line write allocations on x86 platforms).