Tags: python, numpy, jit, numba, boolean-indexing

Python Numba jit function with if statement


I have a piecewise function with three parts that I'm trying to write in Python using Numba's @njit decorator. The function is evaluated element-wise over an array and is defined by:

@njit(parallel=True)
def f(x_vec):
    N=len(x_vec)
    y_vec=np.zeros(N)
    for i in prange(N):
        x=x_vec[i]
        if x<=2000:
            y=64/x
        elif x>=4000:
            y=np.log(x)
        else:
            y=np.log(1.2*x)
        y_vec[i]=y
    return y_vec

I'm using Numba to make this code very fast and run it on all 8 threads of my CPU.

Now, my question is: if I wanted to define each part of the function separately as f1, f2 and f3, and call those inside the if statements (while still benefiting from Numba's speed), how can I do that? The reason is that the subfunctions can be more complicated, and I don't want to make my code hard to read. I want it to be as fast as this version (or only slightly slower).

In order to test the function, we can use this array:

Np=10000000
x_vec=100*np.power(1e8/100,np.random.rand(Np))
%timeit f(x_vec)  #0.06sec on intel core i7 3610
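
As a side note on the test data: the expression above spreads the samples evenly on a logarithmic scale between 100 and 1e8, so all three branches of the function get exercised. A minimal sketch of an equivalent construction, assuming a seeded generator (`rng`, not in the original) and a smaller array so it runs quickly:

```python
import numpy as np

rng = np.random.default_rng(0)  # assumption: seeded generator for repeatable runs
Np = 1_000_000                  # smaller than the original 10_000_000

u = rng.random(Np)                      # uniform samples in [0, 1)
x_vec = 100 * np.power(1e8 / 100, u)    # same form as the original expression
x_alt = 10.0 ** (2 + 6 * u)             # equivalent: exponent uniform in [2, 8)

# Fraction of samples that end up in each branch of the piecewise function
frac_low = np.mean(x_vec <= 2000)
frac_mid = np.mean((x_vec > 2000) & (x_vec < 4000))
frac_high = np.mean(x_vec >= 4000)
print(frac_low, frac_mid, frac_high)
```

With this distribution roughly 22% of the values fall in the first branch, about 5% in the middle one, and the rest in the last, so the middle branch is the least exercised.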

For completeness, the following imports are used:

import numpy as np
from numba import njit, prange

So in this case, the functions would be:

def f1(x):
    return 64/x
def f2(x):
    return np.log(x)
def f3(x):
    return np.log(1.2*x)

The actual functions are the following; they compute the smooth-pipe friction factor for the laminar, transition and turbulent regimes:

@njit
def f1(x):
    return 64/x

@njit
def f2(x):
    #x is the Reynolds number(Re), y is the Darcy friction(f)
    #for transition, we can assume Re=4000 (max possible friction)
    y=0.02
    y=(-2/np.log(10))*np.log(2.51/(4000*np.sqrt(y)))
    return 1/(y*y)

@njit
def f3(x): #colebrook-white approximation
    #x is the Reynolds number(Re), y is the Darcy friction(f)
    y=0.02
    y=(-2/np.log(10))*np.log(2.51/(x*np.sqrt(y)))
    return 1/(y*y)
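
For reference, the expressions in f1 and f3 appear to correspond to the standard laminar relation and to the smooth-pipe form of the Colebrook-White equation, with f3 applying a single fixed-point update from the initial guess 0.02. Written out, using Re for the Reynolds number and f for the Darcy friction factor as in the code comments (the subscripts f_0 and f_1 are notation introduced here):

```latex
% Laminar regime (f1)
f = \frac{64}{\mathrm{Re}}

% Smooth-pipe Colebrook-White relation (implicit in f)
\frac{1}{\sqrt{f}} = -2 \log_{10}\!\left( \frac{2.51}{\mathrm{Re}\,\sqrt{f}} \right)

% One fixed-point step from f_0 = 0.02, as in f3 (f2 uses Re = 4000)
f_1 = \left[ -2 \log_{10}\!\left( \frac{2.51}{\mathrm{Re}\,\sqrt{f_0}} \right) \right]^{-2}
```

The factor `-2/np.log(10)` multiplying `np.log(...)` in the code is the base-10 logarithm written in terms of the natural logarithm.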

Thanks to everyone for the contributions. This is the NumPy solution (the last three lines are slow for some reason, but it doesn't need a warm-up):

y = np.empty_like(x_vec)

a1 = x_vec <= 2000   # the comparison already yields a boolean mask
a3 = x_vec >= 4000
a2 = ~(a1 | a3)      # everything strictly between 2000 and 4000

y[a1] = f1(x_vec[a1])
y[a2] = f2(x_vec[a2])
y[a3] = f3(x_vec[a3])
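
The same mask-based evaluation can also be written with `np.piecewise`, which applies each function only to the elements selected by its condition. A minimal sketch, assuming plain NumPy versions of the three friction-factor functions (no `@njit`) and a small, made-up `x_vec` for illustration:

```python
import numpy as np

def f1(x):
    return 64 / x  # laminar

def f2(x):
    # transition: constant value evaluated at Re = 4000
    y = -2 * np.log10(2.51 / (4000 * np.sqrt(0.02)))
    return 1 / (y * y)

def f3(x):
    # turbulent: one Colebrook-White step starting from 0.02
    y = -2 * np.log10(2.51 / (x * np.sqrt(0.02)))
    return 1 / (y * y)

x_vec = np.array([500.0, 2000.0, 3000.0, 4000.0, 1e6])  # illustrative values

a1 = x_vec <= 2000
a3 = x_vec >= 4000
a2 = ~(a1 | a3)

# Mask-based assignment, as in the snippet above
y_mask = np.empty_like(x_vec)
y_mask[a1] = f1(x_vec[a1])
y_mask[a2] = f2(x_vec[a2])
y_mask[a3] = f3(x_vec[a3])

# Equivalent call: one condition and one function per regime
y_pw = np.piecewise(x_vec, [a1, a2, a3], [f1, f2, f3])

print(np.allclose(y_mask, y_pw))
```

Both variants index the input once per regime, so `np.piecewise` is mainly a readability choice rather than a speed-up.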

The fastest Numba solution, which allows passing the functions as arguments and takes advantage of prange (but is hindered by JIT warm-up), is the following. It can be as fast as the first solution at the top of the question:

@njit(parallel=True)
def f(x_vec,f1,f2,f3):
    N = len(x_vec)
    y_vec = np.zeros(N)
    for i in prange(N):
        x=x_vec[i]
        if x<=2000:
            y=f1(x)
        elif x>=4000:
            y=f3(x)
        else:
            y=f2(x)
        y_vec[i]=y
    return y_vec
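
The warm-up cost mentioned above comes from Numba compiling `f` on its first call for a given set of argument types, so timings are usually taken after one throwaway call. A rough sketch of that measurement, assuming the `@njit` friction-factor functions from earlier in the question; the array size and the use of `time.perf_counter` are illustrative choices, and the fallback branch exists only so the snippet still runs where Numba is not installed:

```python
import time
import numpy as np

try:
    from numba import njit, prange
except ImportError:
    # Assumption: plain-Python stand-ins so the sketch runs without Numba
    prange = range

    def njit(*args, **kwargs):
        if len(args) == 1 and callable(args[0]) and not kwargs:
            return args[0]
        return lambda func: func

@njit
def f1(x):
    return 64 / x

@njit
def f2(x):
    y = 0.02
    y = (-2 / np.log(10)) * np.log(2.51 / (4000 * np.sqrt(y)))
    return 1 / (y * y)

@njit
def f3(x):
    y = 0.02
    y = (-2 / np.log(10)) * np.log(2.51 / (x * np.sqrt(y)))
    return 1 / (y * y)

@njit(parallel=True)
def f(x_vec, f1, f2, f3):
    N = len(x_vec)
    y_vec = np.zeros(N)
    for i in prange(N):
        x = x_vec[i]
        if x <= 2000:
            y = f1(x)
        elif x >= 4000:
            y = f3(x)
        else:
            y = f2(x)
        y_vec[i] = y
    return y_vec

x_vec = 100 * np.power(1e8 / 100, np.random.rand(100_000))

t0 = time.perf_counter()
f(x_vec, f1, f2, f3)           # first call: includes compilation
t1 = time.perf_counter()
y_vec = f(x_vec, f1, f2, f3)   # second call: compiled code is reused
t2 = time.perf_counter()

print(f"first call:  {t1 - t0:.4f} s")
print(f"second call: {t2 - t1:.4f} s")
```

Passing a different set of functions triggers a new compilation, so the throwaway call is needed once per combination of arguments.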

Solution

  • You can write f() to accept function parameters, e.g.:

    @njit
    def f(arr, f1, f2, f3):
        N = len(arr)
        y_vec = np.zeros(N)
        for i in range(N):
            x = arr[i]
            if x <= 2000:
                y = f1(x)
            elif x >= 4000:
                y = f2(x)
            else:
                y = f3(x)
            y_vec[i] = y
        return y_vec
    

    Make sure that the functions you pass are Numba-compatible (for example, decorated with @njit themselves).
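
A short usage sketch of this approach, assuming the simple f1, f2 and f3 from the question (each decorated with `@njit`) and a handful of made-up input values; the `except` branch is only a stand-in so the snippet still runs where Numba is not installed:

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # Assumption: no-op stand-in so the sketch still runs without Numba
    def njit(func):
        return func

@njit
def f1(x):
    return 64 / x

@njit
def f2(x):
    return np.log(x)

@njit
def f3(x):
    return np.log(1.2 * x)

@njit
def f(arr, f1, f2, f3):
    N = len(arr)
    y_vec = np.zeros(N)
    for i in range(N):
        x = arr[i]
        if x <= 2000:
            y = f1(x)
        elif x >= 4000:
            y = f2(x)
        else:
            y = f3(x)
        y_vec[i] = y
    return y_vec

x_vec = np.array([100.0, 2000.0, 3000.0, 4000.0, 1e6])  # illustrative inputs
y_vec = f(x_vec, f1, f2, f3)
print(y_vec)
```

The branch order here follows the original monolithic version: `f2` handles x >= 4000 and `f3` handles the range in between.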