
Numba njit making function call involving long expressions extremely slow


I am writing a finite volume code to solve the inviscid, compressible Euler equations. As a part of this, I am performing what is known as the Cauchy-Kovalevskaya process. A code snippet is given below

def cauchyKovalevskaya(drho_x: npt.NDArray[np.float64], dq_x: npt.NDArray[np.float64], dE_x: npt.NDArray[np.float64], k: int, gamma: float) -> npt.NDArray:
    if k == 0:
        return np.array([drho_x[0], dq_x[0], dE_x[0]], dtype=np.float64)
    
    elif k == 1:
        rho_t = -dq_x[1]

        q_t = -dq_x[0]**2*drho_x[1]*gamma/(2*drho_x[0]**2) + 3*dq_x[0]**2*drho_x[1]/(2*drho_x[0]**2) + dq_x[0]*dq_x[1]*gamma/drho_x[0] - 3*dq_x[0]*dq_x[1]/drho_x[0] - dE_x[1]*gamma + dE_x[1]

        E_t = -dq_x[0]**3*drho_x[1]*gamma/drho_x[0]**3 + dq_x[0]**3*drho_x[1]/drho_x[0]**3 + 3*dq_x[0]**2*dq_x[1]*gamma/(2*drho_x[0]**2) - 3*dq_x[0]**2*dq_x[1]/(2*drho_x[0]**2) - dq_x[0]*dE_x[1]*gamma/drho_x[0] + dq_x[0]*drho_x[1]*dE_x[0]*gamma/drho_x[0]**2 - dq_x[1]*dE_x[0]*gamma/drho_x[0]

        return np.array([rho_t, q_t, E_t], dtype=np.float64)
    
    elif k == 2:
        

        rho_tt = dq_x[0]**2*drho_x[2]*gamma/(2*drho_x[0]**2) - 3*dq_x[0]**2*drho_x[2]/(2*drho_x[0]**2) - dq_x[0]**2*drho_x[1]**2*gamma/drho_x[0]**3 + 3*dq_x[0]**2*drho_x[1]**2/drho_x[0]**3 + 2*dq_x[0]*dq_x[1]*drho_x[1]*gamma/drho_x[0]**2 - 6*dq_x[0]*dq_x[1]*drho_x[1]/drho_x[0]**2 - dq_x[0]*dq_x[2]*gamma/drho_x[0] + 3*dq_x[0]*dq_x[2]/drho_x[0] - dq_x[1]**2*gamma/drho_x[0] + 3*dq_x[1]**2/drho_x[0] + dE_x[2]*gamma - dE_x[2]

        q_tt = dq_x[0]**3*drho_x[2]*gamma**2/(2*drho_x[0]**3) + dq_x[0]**3*drho_x[2]*gamma/drho_x[0]**3 - 7*dq_x[0]**3*drho_x[2]/(2*drho_x[0]**3) - 3*dq_x[0]**3*drho_x[1]**2*gamma**2/(2*drho_x[0]**4) - 3*dq_x[0]**3*drho_x[1]**2*gamma/drho_x[0]**4 + 21*dq_x[0]**3*drho_x[1]**2/(2*drho_x[0]**4) + 5*dq_x[0]**2*dq_x[1]*drho_x[1]*gamma**2/(2*drho_x[0]**3) + 8*dq_x[0]**2*dq_x[1]*drho_x[1]*gamma/drho_x[0]**3 - 45*dq_x[0]**2*dq_x[1]*drho_x[1]/(2*drho_x[0]**3) - dq_x[0]**2*dq_x[2]*gamma**2/(2*drho_x[0]**2) - 5*dq_x[0]**2*dq_x[2]*gamma/(2*drho_x[0]**2) + 6*dq_x[0]**2*dq_x[2]/drho_x[0]**2 - dq_x[0]*dq_x[1]**2*gamma**2/drho_x[0]**2 - 5*dq_x[0]*dq_x[1]**2*gamma/drho_x[0]**2 + 12*dq_x[0]*dq_x[1]**2/drho_x[0]**2 + 3*dq_x[0]*dE_x[2]*gamma/drho_x[0] - 3*dq_x[0]*dE_x[2]/drho_x[0] - dq_x[0]*drho_x[1]*dE_x[1]*gamma**2/drho_x[0]**2 - 2*dq_x[0]*drho_x[1]*dE_x[1]*gamma/drho_x[0]**2 + 3*dq_x[0]*drho_x[1]*dE_x[1]/drho_x[0]**2 - dq_x[0]*drho_x[2]*dE_x[0]*gamma**2/drho_x[0]**2 + dq_x[0]*drho_x[2]*dE_x[0]*gamma/drho_x[0]**2 + 2*dq_x[0]*drho_x[1]**2*dE_x[0]*gamma**2/drho_x[0]**3 - 2*dq_x[0]*drho_x[1]**2*dE_x[0]*gamma/drho_x[0]**3 + dq_x[1]*dE_x[1]*gamma**2/drho_x[0] + 2*dq_x[1]*dE_x[1]*gamma/drho_x[0] - 3*dq_x[1]*dE_x[1]/drho_x[0] - 2*dq_x[1]*drho_x[1]*dE_x[0]*gamma**2/drho_x[0]**2 + 2*dq_x[1]*drho_x[1]*dE_x[0]*gamma/drho_x[0]**2 + dq_x[2]*dE_x[0]*gamma**2/drho_x[0] - dq_x[2]*dE_x[0]*gamma/drho_x[0]

        E_tt = dq_x[0]**4*drho_x[2]*gamma**2/(4*drho_x[0]**4) + 2*dq_x[0]**4*drho_x[2]*gamma/drho_x[0]**4 - 9*dq_x[0]**4*drho_x[2]/(4*drho_x[0]**4) - dq_x[0]**4*drho_x[1]**2*gamma**2/drho_x[0]**5 - 8*dq_x[0]**4*drho_x[1]**2*gamma/drho_x[0]**5 + 9*dq_x[0]**4*drho_x[1]**2/drho_x[0]**5 + dq_x[0]**3*dq_x[1]*drho_x[1]*gamma**2/drho_x[0]**4 + 37*dq_x[0]**3*dq_x[1]*drho_x[1]*gamma/(2*drho_x[0]**4) - 39*dq_x[0]**3*dq_x[1]*drho_x[1]/(2*drho_x[0]**4) - 7*dq_x[0]**3*dq_x[2]*gamma/(2*drho_x[0]**3) + 7*dq_x[0]**3*dq_x[2]/(2*drho_x[0]**3) - 21*dq_x[0]**2*dq_x[1]**2*gamma/(2*drho_x[0]**3) + 21*dq_x[0]**2*dq_x[1]**2/(2*drho_x[0]**3) - dq_x[0]**2*dE_x[2]*gamma**2/(2*drho_x[0]**2) + 3*dq_x[0]**2*dE_x[2]*gamma/drho_x[0]**2 - 3*dq_x[0]**2*dE_x[2]/(2*drho_x[0]**2) + dq_x[0]**2*drho_x[1]*dE_x[1]*gamma**2/(2*drho_x[0]**3) - 15*dq_x[0]**2*drho_x[1]*dE_x[1]*gamma/(2*drho_x[0]**3) + 3*dq_x[0]**2*drho_x[1]*dE_x[1]/drho_x[0]**3 - dq_x[0]**2*drho_x[2]*dE_x[0]*gamma**2/(2*drho_x[0]**3) - 3*dq_x[0]**2*drho_x[2]*dE_x[0]*gamma/(2*drho_x[0]**3) + 3*dq_x[0]**2*drho_x[1]**2*dE_x[0]*gamma**2/(2*drho_x[0]**4) + 9*dq_x[0]**2*drho_x[1]**2*dE_x[0]*gamma/(2*drho_x[0]**4) - dq_x[0]*dq_x[1]*dE_x[1]*gamma**2/drho_x[0]**2 + 8*dq_x[0]*dq_x[1]*dE_x[1]*gamma/drho_x[0]**2 - 3*dq_x[0]*dq_x[1]*dE_x[1]/drho_x[0]**2 - dq_x[0]*dq_x[1]*drho_x[1]*dE_x[0]*gamma**2/drho_x[0]**3 - 7*dq_x[0]*dq_x[1]*drho_x[1]*dE_x[0]*gamma/drho_x[0]**3 + 2*dq_x[0]*dq_x[2]*dE_x[0]*gamma/drho_x[0]**2 + 2*dq_x[1]**2*dE_x[0]*gamma/drho_x[0]**2 + dE_x[0]*dE_x[2]*gamma**2/drho_x[0] - dE_x[0]*dE_x[2]*gamma/drho_x[0] + dE_x[1]**2*gamma**2/drho_x[0] - dE_x[1]**2*gamma/drho_x[0] - drho_x[1]*dE_x[0]*dE_x[1]*gamma**2/drho_x[0]**2 + drho_x[1]*dE_x[0]*dE_x[1]*gamma/drho_x[0]**2


        return np.array([rho_tt, q_tt, E_tt], dtype=np.float64)


    else:
        raise Exception("Invalid order")

There are also statements for k = 3, 4, and 5 which I have excluded for brevity. Suffice to say that these lines become longer and longer (80000+ characters in a line) with loads of similar memory accesses, powers, and multiplications like above.

When I use njit with this code, a) it is slower than without njit (for k = 5 the difference is 500 to 600 times), and b) it gets slower and slower the higher the k I include elif statements for, regardless of which k I am evaluating the function for. As far as I can tell, there should be nothing in the function which makes the performance using Numba so bad.


Solution

  • TL;DR: avoid giving Numba such a large block of code as is, and compile the caller function with Numba as well. If you cannot, the only remaining optimization is to pre-allocate the output (or to skip Numba and write a C extension that does not check its inputs, which is unsafe).


    Optimizing the compilation time

    First of all, the compiled code is pretty huge, especially if the same logic is applied "for k = 3, 4, and 5". This is a problem for two reasons. First, the bigger the code to compile, the longer it takes to build. Second, large compiled code is not optimal at run time either, because the native instructions must be fetched from the cache/RAM, decoded, and then executed. The first two steps are expensive when the code is large, which is why a loop can sometimes be faster than the equivalent long straight-line code.

    To be sure the compilation time is not paid in the middle of your benchmark/application (due to lazy compilation), you can compile the function eagerly by providing the target signature of the function. Note that Numba currently does not care about type annotations like npt.NDArray[np.float64]. Here is an example of a signature:

    @nb.njit('(float64[::1], float64[::1], float64[::1], int64, float64)')
    

    On my machine, the function takes 3.74 seconds to build. This is a lot for such a simple computing function, and a significant part of that time is probably spent in Numba's own Python code (parsing and semantic analysis). While this time is paid only once, when the function is defined, thanks to the eager compilation, slowing down the initialization of your program this much does not seem reasonable to me. The key is to avoid replicated expressions. Note that this does not impact the execution time of the compiled function much, because compilers (like Numba) can detect replicated terms and avoid recomputing them. It also makes the code more readable (avoiding ugly lines of >1700 characters). Here is a slightly better version:

    def cauchyKovalevskaya(drho_x: npt.NDArray[np.float64], dq_x: npt.NDArray[np.float64], dE_x: npt.NDArray[np.float64], k: int, gamma: float) -> npt.NDArray:
        g = gamma
        r0, r1, r2 = drho_x
        q0, q1, q2 = dq_x
        e0, e1, e2 = dE_x
    
        r0p2, r1p2 = r0**2, r1**2
        q0p2, q1p2 = q0**2, q1**2
        e0p2, e1p2 = e0**2, e1**2
        gp2 = g**2
        r0p3, r0p4 = r0**3, r0**4
        q0p3, q0p4 = q0**3, q0**4
    
        if k == 0:
            return np.array([r0, q0, e0], dtype=np.float64)
        
        elif k == 1:
            rho_t = -q1
            q_t = -q0p2*r1*g/(2*r0p2) + 3*q0p2*r1/(2*r0p2) + q0*q1*g/r0 - 3*q0*q1/r0 - e1*g + e1
            E_t = -q0p3*r1*g/r0p3 + q0p3*r1/r0p3 + 3*q0p2*q1*g/(2*r0p2) - 3*q0p2*q1/(2*r0p2) - q0*e1*g/r0 + q0*r1*e0*g/r0p2 - q1*e0*g/r0
            return np.array([rho_t, q_t, E_t], dtype=np.float64)
        
        elif k == 2:
            rho_tt = q0p2*r2*g/(2*r0p2) - 3*q0p2*r2/(2*r0p2) - q0p2*r1p2*g/r0p3 + 3*q0p2*r1p2/r0p3 + 2*q0*q1*r1*g/r0p2 - 6*q0*q1*r1/r0p2 - q0*q2*g/r0 + 3*q0*q2/r0 - q1p2*g/r0 + 3*q1p2/r0 + e2*g - e2
            q_tt = q0p3*r2*gp2/(2*r0p3) + q0p3*r2*g/r0p3 - 7*q0p3*r2/(2*r0p3) - 3*q0p3*r1p2*gp2/(2*r0p4) - 3*q0p3*r1p2*g/r0p4 + 21*q0p3*r1p2/(2*r0p4) + 5*q0p2*q1*r1*gp2/(2*r0p3) + 8*q0p2*q1*r1*g/r0p3 - 45*q0p2*q1*r1/(2*r0p3) - q0p2*q2*gp2/(2*r0p2) - 5*q0p2*q2*g/(2*r0p2) + 6*q0p2*q2/r0p2 - q0*q1p2*gp2/r0p2 - 5*q0*q1p2*g/r0p2 + 12*q0*q1p2/r0p2 + 3*q0*e2*g/r0 - 3*q0*e2/r0 - q0*r1*e1*gp2/r0p2 - 2*q0*r1*e1*g/r0p2 + 3*q0*r1*e1/r0p2 - q0*r2*e0*gp2/r0p2 + q0*r2*e0*g/r0p2 + 2*q0*r1p2*e0*gp2/r0p3 - 2*q0*r1p2*e0*g/r0p3 + q1*e1*gp2/r0 + 2*q1*e1*g/r0 - 3*q1*e1/r0 - 2*q1*r1*e0*gp2/r0p2 + 2*q1*r1*e0*g/r0p2 + q2*e0*gp2/r0 - q2*e0*g/r0
            E_tt = q0p4*r2*gp2/(4*r0p4) + 2*q0p4*r2*g/r0p4 - 9*q0p4*r2/(4*r0p4) - q0p4*r1p2*gp2/r0**5 - 8*q0p4*r1p2*g/r0**5 + 9*q0p4*r1p2/r0**5 + q0p3*q1*r1*gp2/r0p4 + 37*q0p3*q1*r1*g/(2*r0p4) - 39*q0p3*q1*r1/(2*r0p4) - 7*q0p3*q2*g/(2*r0p3) + 7*q0p3*q2/(2*r0p3) - 21*q0p2*q1p2*g/(2*r0p3) + 21*q0p2*q1p2/(2*r0p3) - q0p2*e2*gp2/(2*r0p2) + 3*q0p2*e2*g/r0p2 - 3*q0p2*e2/(2*r0p2) + q0p2*r1*e1*gp2/(2*r0p3) - 15*q0p2*r1*e1*g/(2*r0p3) + 3*q0p2*r1*e1/r0p3 - q0p2*r2*e0*gp2/(2*r0p3) - 3*q0p2*r2*e0*g/(2*r0p3) + 3*q0p2*r1p2*e0*gp2/(2*r0p4) + 9*q0p2*r1p2*e0*g/(2*r0p4) - q0*q1*e1*gp2/r0p2 + 8*q0*q1*e1*g/r0p2 - 3*q0*q1*e1/r0p2 - q0*q1*r1*e0*gp2/r0p3 - 7*q0*q1*r1*e0*g/r0p3 + 2*q0*q2*e0*g/r0p2 + 2*q1p2*e0*g/r0p2 + e0*e2*gp2/r0 - e0*e2*g/r0 + e1p2*gp2/r0 - e1p2*g/r0 - r1*e0*e1*gp2/r0p2 + r1*e0*e1*g/r0p2
            return np.array([rho_tt, q_tt, E_tt], dtype=np.float64)
    
        else:
            raise Exception("Invalid order")
    

    It is still far from great, but I cannot find a generic pattern for all sub-expressions (it is simpler to expand an expression than to factorize it).

    This function takes 1.55 seconds to build on my machine (with the same execution time). This is about 2.4 times faster to build.

    You can reduce this overhead further using the compilation flag cache=True.
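    To illustrate what removing replicated terms can look like, here is a minimal sketch for the k = 1 momentum term only. The function names and the intermediate variable u = q0/r0 are assumptions made for this illustration; the factorization was done by hand for this single expression and is not a pattern claimed to hold for the higher orders.

```python
import math


def q_t_expanded(r0, r1, q0, q1, e1, g):
    # Same expression as q_t in the k == 1 branch, written term by term.
    return (-q0**2*r1*g/(2*r0**2) + 3*q0**2*r1/(2*r0**2)
            + q0*q1*g/r0 - 3*q0*q1/r0 - e1*g + e1)


def q_t_factored(r0, r1, q0, q1, e1, g):
    # Hand-factored form: the ratio q0/r0 is computed once and the
    # gamma-dependent coefficients are grouped.
    u = q0 / r0
    return (3.0 - g) * (0.5*u*u*r1 - u*q1) + (1.0 - g)*e1


# Both forms agree up to floating-point rounding.
args = (1.2, 0.3, 0.7, -0.4, 0.9, 1.4)
assert math.isclose(q_t_expanded(*args), q_t_factored(*args), rel_tol=1e-12)
```

    The factored form contains far fewer operations for Numba to parse and type, which is what drives the build time. The results may differ in the last bits because the operations are evaluated in a different order.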


    Setup and Benchmark

    Here is the setup used to measure the performance of your code:

    import numpy as np
    import numpy.typing as npt
    import numba as nb
    
    %%time
    @nb.njit('(float64[::1], float64[::1], float64[::1], int64, float64)')
    def cauchyKovalevskaya(drho_x: npt.NDArray[np.float64], dq_x: npt.NDArray[np.float64], dE_x: npt.NDArray[np.float64], k: int, gamma: float) -> npt.NDArray:
        # [...]
    
    drho_x = np.random.rand(3)
    dq_x = np.random.rand(3)
    dE_x = np.random.rand(3)
    gamma = 42.0
    %timeit -n 20_000 cauchyKovalevskaya(drho_x, dq_x, dE_x, 0, gamma)
    %timeit -n 20_000 cauchyKovalevskaya(drho_x, dq_x, dE_x, 1, gamma)
    %timeit -n 20_000 cauchyKovalevskaya(drho_x, dq_x, dE_x, 2, gamma)
    

    Regarding the execution time, here are the timings on my machine (with eager compilation):

    1.27 µs ± 8.18 ns per loop (mean ± std. dev. of 7 runs, 20000 loops each)
    1.28 µs ± 16 ns per loop (mean ± std. dev. of 7 runs, 20000 loops each)
    1.37 µs ± 24.4 ns per loop (mean ± std. dev. of 7 runs, 20000 loops each)
    

    Performance Analysis and Solutions

    More than 90% of the time comes from overheads (creating the output array, checking the input types and converting them to native types, calling a native function from CPython) and not from the computation itself. Pre-allocating the output helps to reduce the execution time a bit. That being said, there is no way to massively reduce the execution time as long as this function is called from CPython: this is the main bottleneck. You can compile the caller function with Numba to significantly reduce this overhead.

    On my machine, it looks like the computation takes only 0.1 µs.

    Once the caller function is compiled with Numba, note that you can add the compilation flag fastmath=True to accelerate math computations. That being said, this option is harmful if values can be NaN, -inf, +inf, or subnormal, or if the associativity of the operations matters. For more information about this, please read this post. I do not recommend it in this case because I do not expect a huge speed-up on the modified code.