python numpy performance numba

Why is Jitted Numba function slower than original function?


I've written a function to create uniformly spaced points on a disk. Since it's run quite often and on relatively large arrays, I figured applying numba would increase the speed significantly. However, a quick test showed that the numba function is more than twice as slow.

Is there a way to figure out what is slowing down the numba function?

Here's the function:

import numpy as np
from numba import njit

@njit(cache=True)
def generate_points_turbo(centre_point, radius, num_rings, x_axis=np.array([-1, 0, 0]), y_axis=np.array([0, 1, 0])):
    """
    Generate uniformly spaced points inside a circle
    Based on algorithm from:
    http://www.holoborodko.com/pavel/2015/07/23/generating-equidistant-points-on-unit-disk/
    
    Parameters
    ----------
    centre_point : np.ndarray (1, 3)
    radius : float/int
    num_rings : int
    x_axis : np.ndarray
    y_axis : np.ndarray

    Returns
    -------
    points : np.ndarray (n, 3)

    """
    if num_rings > 0:
        delta_R = 1 / num_rings
        ring_radii = np.linspace(delta_R, 1, int(num_rings)) * radius
        k = np.arange(num_rings) + 1
        points_per_ring = np.rint(np.pi / np.arcsin(1 / (2*k))).astype(np.int32)
        num_points = points_per_ring.sum() + 1
        ring_indices = np.zeros(int(num_rings)+1)
        ring_indices[1:] = points_per_ring.cumsum()
        ring_indices += 1
        points = np.zeros((num_points, 3))

        points[0, :] = centre_point

        for indx in range(len(ring_radii)):
            theta = np.linspace(0, 2 * np.pi, points_per_ring[indx]+1)
            points[ring_indices[indx]:ring_indices[indx+1], :] = ((ring_radii[indx] * np.cos(theta[1:]) * x_axis[:, None]).T
                     + (ring_radii[indx] * np.sin(theta[1:]) * y_axis[:, None]).T)
        return points + centre_point

And it's called like this:

centre_point = np.array([0,0,0])
radius = 1
num_rings = 15

generate_points_turbo(centre_point, radius, num_rings )

It would be great if someone knows why the function is slower when compiled with numba, or how to go about finding the bottleneck in the numba function.
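One generic way to look for a bottleneck is to time the individual steps separately, always after a warm-up call, because the first call to a jitted function also pays for compilation (or for loading the cached machine code). The helper `best_time` and the two stand-in steps below are hypothetical names used only for illustration; this is a minimal sketch in plain Python and NumPy, not a Numba-specific profiler.

```python
import timeit

import numpy as np


def best_time(func, *args, repeat=5, number=200):
    """Return the best per-call time in seconds, measured after one warm-up call."""
    func(*args)  # Warm-up: a jitted function compiles (or loads its cache) here.
    totals = timeit.repeat(lambda: func(*args), repeat=repeat, number=number)
    return min(totals) / number


# Hypothetical stand-ins for two steps of the point generator.
def ring_angles(n):
    # n angles on a ring, skipping 0 so that 0 and 2*pi are not both present.
    return np.linspace(0, 2 * np.pi, n + 1)[1:]


def ring_coords(n):
    theta = ring_angles(n)
    return np.cos(theta), np.sin(theta)


for step in (ring_angles, ring_coords):
    print(f"{step.__name__}: {best_time(step, 100) * 1e6:.1f} µs per call")
```

The same helper can be pointed at the jitted and the non-jitted version of a step to see which one dominates at a given input size.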

Update: Possible computer specific size dependence

It seems the numba function is working, but the cross-over point between where it's slower and where it's faster may be hardware specific.

%timeit generate_points(centre_point, 1, 2)
99.5 µs ± 932 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
%timeit generate_points_turbo(centre_point, 1, 2)
213 µs ± 8.4 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

%timeit generate_points(centre_point, 1, 20)
647 µs ± 11.2 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
%timeit generate_points_turbo(centre_point, 1, 20)
314 µs ± 8.74 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

%timeit generate_points(centre_point, 1, 200)
11.9 ms ± 375 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit generate_points_turbo(centre_point, 1, 200)
7.9 ms ± 243 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

From about 12-15 rings onward, the numba function (*_turbo) is as fast as or faster than the original on my machine, but the gains at larger sizes are smaller than I expected. So it does seem to be working; some part of the function is just heavily size dependent.
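To put the three test sizes in perspective, the ring-size formula from the function can be evaluated on its own. For large k, arcsin(1/(2k)) is close to 1/(2k), so ring k holds about 2πk points and the total grows roughly like π·n·(n+1) for n rings. The helper name `points_per_ring` below is chosen only for this sketch.

```python
import numpy as np


def points_per_ring(num_rings):
    """Number of points on each ring, using the same formula as the function above."""
    k = np.arange(1, num_rings + 1)
    return np.rint(np.pi / np.arcsin(1 / (2 * k))).astype(np.int64)


# The first three rings hold 6, 12 and 19 points.
print(points_per_ring(3))

# Total number of points (rings plus the centre) for the sizes that were timed.
for num_rings in (2, 20, 200):
    total = int(points_per_ring(num_rings).sum()) + 1
    print(f"{num_rings} rings -> {total} points")
```

Going from 20 to 200 rings therefore increases the amount of work by roughly a factor of 100, not 10, which is worth keeping in mind when comparing the timings.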


Solution

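  • I removed the transpositions, the newaxis indexing and the unused third dimension, which gave roughly a 20x speed-up over your original version (see the timings at the end). I also replaced range with prange, since the order in which the rings are computed does not matter. Note that prange only runs in parallel when the function is compiled with parallel=True; with the decorator used here it behaves like range, and the running offset n would have to be computed per ring before enabling it.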

    # Imports.
    import matplotlib.pyplot as plt
    from numba import njit, prange
    import numpy as np
    
    # "Turbo" function.
    @njit(cache=True)
    def generate_points_turbo(centre_point, radius, num_rings):
        """
        Generate uniformly spaced points inside a circle
        Based on algorithm from:
        http://www.holoborodko.com/pavel/2015/07/23/generating-equidistant-points-on-unit-disk/
    
        Parameters
        ----------
        centre_point : np.ndarray (2,)
        radius : float/int
        num_rings : int
    
        Returns
        -------
        points : np.ndarray (n, 2)
    
        """
        if not num_rings > 0:
            return
    
        delta_R = 1 / num_rings
        ring_radii = np.linspace(delta_R, 1, num_rings) # Use a unit circle that we will scale only at the end.
        k = np.arange(num_rings) + 1
        points_per_ring = np.rint(np.pi / np.arcsin(1 / (2*k))).astype(np.int32)
        num_points = points_per_ring.sum() + 1
    
        points = np.zeros((num_points, 2))
        n = 1 # n == 0 is the central point by design.
    
        for ring_number in prange(len(ring_radii)):
            r = ring_radii[ring_number] # The local radius between 1/num_rings and 1.
            points_on_this_ring = points_per_ring[ring_number]
            theta = np.linspace(0, 2 * np.pi, points_on_this_ring + 1)[1:] # Skip angle 0 so it is not duplicated by 2*pi.
            points[n: n+points_on_this_ring, 0] = r * np.cos(theta)
            points[n: n+points_on_this_ring, 1] = r * np.sin(theta)
            n += points_on_this_ring
    
        return points * radius + centre_point
    
    
    
    # Test that the result is accurate.
    if __name__ == "__main__":
    
        centre_point = np.array([0, 0])
        radius = 3.14159
        num_rings = 10
    
        p = generate_points_turbo(centre_point, radius, num_rings)
        fig, ax = plt.subplots()
        ax.set_aspect(1)
        ax.scatter(*p.T)
        plt.show() # Blocks until the window is closed when run as a script.
    


    # Test time taken.
    >>> from timeit import timeit
    >>> from initial_code import generate_points_turbo as generate_points_turbo_stackoverflow

    >>> %timeit generate_points_turbo(centre_point, radius, num_rings)
    13.5 µs ± 21.2 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
    >>> %timeit generate_points_turbo_stackoverflow(np.array([0, 0, 0]), radius, num_rings)
    261 µs ± 1.03 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
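The loop in the function above carries a running offset n from one ring to the next, so its iterations are not independent. If the rings were ever to be computed in parallel, each ring would need to know its own slice up front. Below is a minimal sketch of that idea in plain NumPy, assuming num_rings >= 1 and a 2-element centre_point; the function name `generate_points_offsets` is hypothetical. It uses an ordinary range loop and avoids keyword arguments to np.linspace, so the body stays close to what the jitted version does, but whether it pays off under Numba with parallel=True would need to be measured.

```python
import numpy as np


def generate_points_offsets(centre_point, radius, num_rings):
    """Variant in which every ring writes to a precomputed, non-overlapping slice.

    Assumes num_rings >= 1 and centre_point of shape (2,).
    """
    k = np.arange(1, num_rings + 1)
    sizes = np.rint(np.pi / np.arcsin(1 / (2 * k))).astype(np.int64)

    # Row 0 is the centre; ring i occupies rows starts[i] .. stops[i] - 1.
    stops = 1 + np.cumsum(sizes)
    starts = stops - sizes

    points = np.zeros((stops[-1], 2))
    for i in range(num_rings):  # Each iteration touches only its own rows.
        n = sizes[i]
        r = k[i] / num_rings  # Unit-disk radius of this ring.
        theta = 2 * np.pi * np.arange(n) / n  # n distinct angles, 2*pi excluded.
        points[starts[i]:stops[i], 0] = r * np.cos(theta)
        points[starts[i]:stops[i], 1] = r * np.sin(theta)

    return points * radius + centre_point


p = generate_points_offsets(np.array([0.0, 0.0]), 3.14159, 10)
print(p.shape)
```

Because the slices are fixed before the loop starts, the rings can be filled in any order and the result is the same.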