I'm trying to speed up some calculations using Numba. Calling the nnleapfrog_integrate
function below leads to a segmentation fault and a crash of the Python process. The function works fine if the parallel=True
flag is removed from its jit decorator, but then it runs in a single thread. I want this function to run as fast as possible, so I want it to run multi-threaded and utilize all the CPU cores.
from numba import jit, prange
import numpy as np

@jit('Tuple((f8[:,:,::1],f8[:,:,::1]))(f8[:,::1], f8[:,::1], f8[::1], i8, i8, i8, f8, f8)', nopython=True, parallel=True)
def nnleapfrog_integrate(pos, vel, mass, i_steps, r_steps, dt, G, softening):
    N = pos.shape[0]
    pos_data = np.zeros((int(np.ceil(i_steps/r_steps)), N, 3))
    vel_data = np.zeros((int(np.ceil(i_steps/r_steps)), N, 3))
    data_idx = 0
    acc = np.zeros((N,3))
    for s in range(i_steps):
        vel += acc * dt/2.0
        pos += vel * dt
        for i in prange(N):
            acc[i,0] = 0
            acc[i,1] = 0
            acc[i,2] = 0
            for j in range(N):
                dx = pos[j,0] - pos[i,0]
                dy = pos[j,1] - pos[i,1]
                dz = pos[j,2] - pos[i,2]
                inv_r3 = (dx**2 + dy**2 + dz**2 + softening**2)**(-1.5)
                acc[i,0] += G * (dx * inv_r3) * mass[j]
                acc[i,1] += G * (dy * inv_r3) * mass[j]
                acc[i,2] += G * (dz * inv_r3) * mass[j]
        vel += acc * dt/2.0
        if s % r_steps == 0:
            pos_data[data_idx] = pos
            vel_data[data_idx] = vel
            data_idx += 1
    return pos_data, vel_data
N = 10
dt = 60
pos = np.random.rand(N, 3)
vel = np.random.rand(N, 3)
m = np.random.rand(N)
softening = 1e3
G = 6.67430e-11
t_max = 3600*24*30
i_steps = int(t_max/dt)
r_steps = int(3600*24/dt)
r_i, v_i = nnleapfrog_integrate(pos, vel, m, i_steps, r_steps, dt, G, softening)
Because only the for i in prange(N):
loop is suitable for parallelization, I moved it into a separate function, getAcc,
which works fine with the parallel=True
flag and utilizes all the CPU cores.
from numba import jit, prange
import numpy as np

@jit('f8[:, ::1](f8[:, ::1], f8[::1], f8, f8)', nopython=True, parallel=True)
def getAcc( pos, mass, G, softening ):
    N = pos.shape[0]
    a = np.zeros((N,3))
    for i in prange(N):
        for j in range(N):
            dx = pos[j,0] - pos[i,0]
            dy = pos[j,1] - pos[i,1]
            dz = pos[j,2] - pos[i,2]
            inv_r3 = (dx**2 + dy**2 + dz**2 + softening**2)**(-1.5)
            a[i,0] += G * (dx * inv_r3) * mass[j]
            a[i,1] += G * (dy * inv_r3) * mass[j]
            a[i,2] += G * (dz * inv_r3) * mass[j]
    return a
@jit('Tuple((f8[:,:,::1],f8[:,:,::1]))(f8[:,::1], f8[:,::1], f8[::1], i8, i8, i8, f8, f8)', nopython=True)
def nleapfrog_integrate(pos, vel, mass, i_steps, r_steps, dt, G, softening):
    N = pos.shape[0]
    pos_data = np.zeros((int(np.ceil(i_steps/r_steps)), N, 3))
    vel_data = np.zeros((int(np.ceil(i_steps/r_steps)), N, 3))
    data_idx = 0
    acc = getAcc(pos, mass, G, softening)
    for i in range(i_steps):
        vel += acc * dt/2.0
        pos += vel * dt
        acc = getAcc( pos, mass, G, softening )
        vel += acc * dt/2.0
        if i % r_steps == 0:
            pos_data[data_idx] = pos
            vel_data[data_idx] = vel
            data_idx += 1
    return pos_data, vel_data
N = 10
dt = 60
pos = np.random.rand(N, 3)
vel = np.random.rand(N, 3)
m = np.random.rand(N)
softening = 1e3
G = 6.67430e-11
t_max = 3600*24*30
i_steps = int(t_max/dt)
r_steps = int(3600*24/dt)
r_i, v_i = nleapfrog_integrate(pos, vel, m, i_steps, r_steps, dt, G, softening)
But it turned out to be more than 3 times slower than the single-threaded version of the original function, in which this loop was inlined.
In [4]: %timeit r_i, v_i = nleapfrog_integrate(pos, vel, m, i_steps, r_steps, dt, G, softening)
8.51 s ± 46.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
In [5]: %timeit r_i, v_i = nnleapfrog_integrate(pos, vel, m, i_steps, r_steps, dt, G, softening)
2.53 s ± 18.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Therefore, for the best performance, I need the original function, with the for i in prange(N):
loop inlined, to run multi-threaded.
Parallelizing the i-based loop is not efficient because creating and synchronizing threads is expensive. This overhead is usually at least dozens of microseconds on a PC and often significantly bigger on computing servers (mainly because of the additional cores). The problem is that there are i_steps = 43200
iterations of the outer loop, so this overhead alone adds up to a few seconds. With N = 10
there is simply not enough work per step to use threads efficiently.
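As a rough back-of-the-envelope check (the per-region overhead below is an assumed, typical desktop figure, not a measurement on this machine):

i_steps = 43200                      # number of leapfrog steps in the example
overhead_per_step = 20e-6            # assumed thread fork/join cost per parallel region, in seconds
print(i_steps * overhead_per_step)   # ~0.86 s spent purely on threading overhead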
Besides, note that there is a bug in Numba 0.57.0 causing a segmentation fault on this code, so I am not sure it is even safe to parallelize it.
Fortunately, the serial code can be optimized.
First, x * y**(-1.5)
is not efficient because Numba uses the expensive pow
function to compute it. You can use x / (y * sqrt(y))
instead. This is significantly faster because most CPUs have a hardware unit that computes square roots and divisions relatively efficiently.
Second, the fastmath
option is not enabled, so Numba cannot reorder or reassociate floating-point products (for example, it cannot assume that (x*y)*z
equals x*(y*z)
), which prevents some optimizations. Enabling this flag can be dangerous, but the same optimization can be done manually by pre-computing values in the inner loop. Here is the resulting optimized code:
@jit('Tuple((f8[:,:,::1],f8[:,:,::1]))(f8[:,::1], f8[:,::1], f8[::1], i8, i8, i8, f8, f8)', nopython=True)
def nnleapfrog_integrate(pos, vel, mass, i_steps, r_steps, dt, G, softening):
    N = pos.shape[0]
    pos_data = np.zeros((int(np.ceil(i_steps/r_steps)), N, 3))
    vel_data = np.zeros((int(np.ceil(i_steps/r_steps)), N, 3))
    data_idx = 0
    acc = np.zeros((N,3))
    for s in range(i_steps):
        vel += acc * dt/2.0
        pos += vel * dt
        for i in range(N):
            acc[i,0] = 0.0
            acc[i,1] = 0.0
            acc[i,2] = 0.0
            for j in range(N):
                dx = pos[j,0] - pos[i,0]
                dy = pos[j,1] - pos[i,1]
                dz = pos[j,2] - pos[i,2]
                tmp1 = dx**2 + dy**2 + dz**2 + softening**2
                # tmp1**(-1.5) replaced by one division and one square root
                tmp2 = G * mass[j] / (tmp1 * np.sqrt(tmp1))
                acc[i,0] += tmp2 * dx
                acc[i,1] += tmp2 * dy
                acc[i,2] += tmp2 * dz
        vel += acc * dt/2.0
        if s % r_steps == 0:
            pos_data[data_idx] = pos
            vel_data[data_idx] = vel
            data_idx += 1
    return pos_data, vel_data
This code is about 5 times faster on my machine with an i5-9600KF processor. It runs in approximately 31 ms, which means every iteration of the outer loop takes only about 0.72 µs (far smaller than the overhead of thread creation/synchronization).
Further optimizations include pre-computing G * mass[j]
and computing the division/square root using SIMD instructions. The former is easy to do (see the sketch below); the latter is a bit tricky, especially in Numba.
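As an illustration of the first point, here is a minimal sketch of the acceleration kernel with G * mass[j] hoisted out of the inner loop. The standalone helper get_acc_precomputed is only a hypothetical example to show the idea, not a benchmarked implementation:

from numba import jit
import numpy as np

@jit('f8[:,::1](f8[:,::1], f8[::1], f8, f8)', nopython=True)
def get_acc_precomputed(pos, mass, G, softening):
    # Same O(N^2) acceleration kernel as above, but G*mass[j] is computed
    # once per body instead of once per (i, j) pair.
    N = pos.shape[0]
    Gm = G * mass                 # precomputed once; mass does not change
    soft2 = softening**2
    acc = np.zeros((N, 3))
    for i in range(N):
        for j in range(N):
            dx = pos[j,0] - pos[i,0]
            dy = pos[j,1] - pos[i,1]
            dz = pos[j,2] - pos[i,2]
            tmp1 = dx**2 + dy**2 + dz**2 + soft2
            tmp2 = Gm[j] / (tmp1 * np.sqrt(tmp1))
            acc[i,0] += tmp2 * dx
            acc[i,1] += tmp2 * dy
            acc[i,2] += tmp2 * dz
    return acc

Since mass is constant over the whole integration, the Gm array could even be computed once per call, outside the step loop, rather than once per step.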