python numpy jit multicore numba

How to make numba @jit use all cpu cores (parallelize numba @jit)


I am using Numba's @jit decorator to add two NumPy arrays in Python. With @jit the code runs much faster than the pure-Python version.

However, it does not use all CPU cores, even when I pass @numba.jit(nopython=True, parallel=True, nogil=True).

Is there any way to make use of all CPU cores with Numba's @jit?

Here is my code:

import time
import numpy as np
import numba

SIZE = 2147483648 * 6

a = np.full(SIZE, 1, dtype=np.int32)
b = np.full(SIZE, 1, dtype=np.int32)
c = np.ndarray(SIZE, dtype=np.int32)

@numba.jit(nopython=True, parallel=True, nogil=True)
def add(a, b, c):
    for i in range(SIZE):
        c[i] = a[i] + b[i]

start = time.time()
add(a, b, c)
end = time.time()

print(end - start)

Solution

  • You can pass parallel=True to any Numba-jitted function, but that doesn't mean it will always use all cores. Numba uses heuristics to decide what can execute in parallel, and sometimes these heuristics simply don't find anything to parallelize in the code. There's currently a pull request so that it issues a warning if it wasn't possible to make the function "parallel". So the parameter means "please execute this in parallel if possible" rather than "enforce parallel execution".

    However, you can always use threads or processes manually if you know that your code can be parallelized. Here is an adaptation of the multi-threading example from the Numba docs:

    #!/usr/bin/env python
    from __future__ import print_function, division, absolute_import
    
    import math
    import threading
    from timeit import repeat
    
    import numpy as np
    from numba import jit
    
    nthreads = 4
    size = 10**7  # CHANGED
    
    # CHANGED
    def func_np(a, b):
        """
        Control function using Numpy.
        """
        return a + b
    
    # CHANGED
    @jit('void(double[:], double[:], double[:])', nopython=True, nogil=True)
    def inner_func_nb(result, a, b):
        """
        Function under test.
        """
        for i in range(len(result)):
            result[i] = a[i] + b[i]
    
    def timefunc(correct, s, func, *args, **kwargs):
        """
        Benchmark *func* and print out its runtime.
        """
        print(s.ljust(20), end=" ")
        # Make sure the function is compiled before we start the benchmark
        res = func(*args, **kwargs)
        if correct is not None:
            assert np.allclose(res, correct), (res, correct)
        # time it
        print('{:>5.0f} ms'.format(min(repeat(lambda: func(*args, **kwargs),
                                              number=5, repeat=2)) * 1000))
        return res
    
    def make_singlethread(inner_func):
        """
        Run the given function inside a single thread.
        """
        def func(*args):
            length = len(args[0])
            result = np.empty(length, dtype=np.float64)
            inner_func(result, *args)
            return result
        return func
    
    def make_multithread(inner_func, numthreads):
        """
        Run the given function inside *numthreads* threads, splitting its
        arguments into equal-sized chunks.
        """
        def func_mt(*args):
            length = len(args[0])
            result = np.empty(length, dtype=np.float64)
            args = (result,) + args
            chunklen = (length + numthreads - 1) // numthreads
            # Create argument tuples for each input chunk
            chunks = [[arg[i * chunklen:(i + 1) * chunklen] for arg in args]
                      for i in range(numthreads)]
            # Spawn one thread per chunk
            threads = [threading.Thread(target=inner_func, args=chunk)
                       for chunk in chunks]
            for thread in threads:
                thread.start()
            for thread in threads:
                thread.join()
            return result
        return func_mt
    
    
    func_nb = make_singlethread(inner_func_nb)
    func_nb_mt = make_multithread(inner_func_nb, nthreads)
    
    a = np.random.rand(size)
    b = np.random.rand(size)
    
    correct = timefunc(None, "numpy (1 thread)", func_np, a, b)
    timefunc(correct, "numba (1 thread)", func_nb, a, b)
    timefunc(correct, "numba (%d threads)" % nthreads, func_nb_mt, a, b)
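
    The chunk arithmetic inside make_multithread can be looked at in isolation. This is a small sketch using a hypothetical helper, chunk_bounds, that is not part of the original example; it only reproduces the ceiling division and the way slicing clamps at the end of the array:

```python
def chunk_bounds(length, numthreads):
    """Return one (start, stop) index pair per thread (hypothetical helper)."""
    # Ceiling division: the smallest chunk length that covers every element.
    chunklen = (length + numthreads - 1) // numthreads
    # Clamp to length, mirroring how a slice past the end of an array behaves.
    return [(min(i * chunklen, length), min((i + 1) * chunklen, length))
            for i in range(numthreads)]


bounds = chunk_bounds(10, 4)
print(bounds)  # [(0, 3), (3, 6), (6, 9), (9, 10)]
```

    Because basic NumPy slices are views rather than copies, each thread in make_multithread writes its chunk directly into the shared result array, so no merging step is needed after the threads are joined.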
    

    The parts I changed are marked with # CHANGED comments; everything else was copied verbatim from the example. This utilizes all cores on my machine (a 4-core machine, therefore 4 threads) but doesn't show a significant speedup:

    numpy (1 thread)       539 ms
    numba (1 thread)       536 ms
    numba (4 threads)      442 ms
    

    The reason for the lack of (much) speedup with multithreading in this case is that addition is a bandwidth-limited operation: it takes much more time to load the elements from the arrays and store the result in the result array than to do the actual addition.

    In these cases you could even see slowdowns because of parallel execution!
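
    To put rough numbers on the memory traffic, the sketch below does the arithmetic for the addition benchmark above. It is a back-of-the-envelope estimate only: it assumes 8-byte float64 elements and counts two loads and one store per element, and it ignores caches and the cost of allocating the result array, which the measured time also includes.

```python
size = 10**7       # elements per array, as in the benchmark
itemsize = 8       # bytes per float64 element
calls = 5          # timefunc runs each function number=5 times per timing

# One addition per element needs two 8-byte loads and one 8-byte store.
bytes_per_call = 3 * itemsize * size
elapsed_s = 0.539  # "numpy (1 thread)" timing reported above

gb_per_s = bytes_per_call * calls / elapsed_s / 1e9
print(f"{bytes_per_call / 1e6:.0f} MB moved per call")
print(f"{gb_per_s:.2f} GB/s effective throughput")
```

    Every call moves 24 bytes for each single addition, so the cores spend most of their time waiting for memory, and adding more threads mostly adds more waiters on the same memory bus.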

    Only if the function is more complex, so that the actual operation takes significant time compared to loading and storing array elements, will you see a big improvement from parallel execution. The example in the Numba documentation is one like that:

    def func_np(a, b):
        """
        Control function using Numpy.
        """
        return np.exp(2.1 * a + 3.2 * b)
    
    @jit('void(double[:], double[:], double[:])', nopython=True, nogil=True)
    def inner_func_nb(result, a, b):
        """
        Function under test.
        """
        for i in range(len(result)):
            result[i] = math.exp(2.1 * a[i] + 3.2 * b[i])
    

    This actually scales (almost) linearly with the number of threads, because two multiplications, one addition, and one call to math.exp take much longer than loading and storing the array elements:

    func_nb = make_singlethread(inner_func_nb)
    func_nb_mt2 = make_multithread(inner_func_nb, 2)
    func_nb_mt3 = make_multithread(inner_func_nb, 3)
    func_nb_mt4 = make_multithread(inner_func_nb, 4)
    
    a = np.random.rand(size)
    b = np.random.rand(size)
    
    correct = timefunc(None, "numpy (1 thread)", func_np, a, b)
    timefunc(correct, "numba (1 thread)", func_nb, a, b)
    timefunc(correct, "numba (2 threads)", func_nb_mt2, a, b)
    timefunc(correct, "numba (3 threads)", func_nb_mt3, a, b)
    timefunc(correct, "numba (4 threads)", func_nb_mt4, a, b)
    

    Result:

    numpy (1 thread)      3422 ms
    numba (1 thread)      2959 ms
    numba (2 threads)     1555 ms
    numba (3 threads)     1080 ms
    numba (4 threads)      797 ms
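
    The scaling can be read off these numbers directly. The short sketch below computes the speedup relative to the single-threaded Numba run and the parallel efficiency (speedup divided by thread count); the variable names are illustrative and the timings are the ones listed above:

```python
# Numba timings from the results above, in milliseconds, keyed by thread count.
timings_ms = {1: 2959, 2: 1555, 3: 1080, 4: 797}
baseline = timings_ms[1]

# Map thread count -> (speedup, efficiency).
scaling = {n: (baseline / t, baseline / t / n) for n, t in timings_ms.items()}

for n, (speedup, efficiency) in scaling.items():
    print(f"{n} thread(s): speedup {speedup:.2f}x, efficiency {efficiency:.0%}")
```

    With 4 threads the speedup is about 3.7x, which corresponds to roughly 93% parallel efficiency on this machine.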