I am implementing a numerical method for solving stochastic differential equations using the Euler-Maruyama method.
What I have works, but it is not efficient: because of the stochastic nature of the problem, I have many trajectories, and right now I am solving them one by one. Since they are independent, I have the feeling I should be able to parallelize them.
The working code looks like this:
```python
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import time
from numba import jit, njit
import os

def A(u):
    x=u[0]
    y=u[1]
    z=u[2]
    omega=1/2*np.sqrt((1+8*kappa*z*z))
    A=np.array([[-2,omega,0],
                [-omega,0,0],
                [0,0,-kappa]])
    du=A.dot(u)
    return du

def B(u,w):
    x=u[0]
    y=u[1]
    z=u[2]
    g=np.sqrt(kappa*nth)
    B=np.array([[0],
                [1],
                [1]])*g
    return np.reshape(B*w,len(u0))

def SDE(A,B):
    u = np.zeros((len(u0),Nmax+1,Mmax),dtype=np.complex64)
    for m in range(Mmax):
        u[:,0,m]=u0
        for n in range(0,Nmax):
            u[:,n+1,m] = u[:,n,m]+dt*A(u[:,n,m])+B(u[:,n,m],w[n,m])*np.sqrt(dt)
    return u

#Parameters
kappa=0.05
nth=1.
gamma=1
Mmax=100 #number of trajectories
Tmax=10. ##max value for time
dt=0.05
Nmax=int(Tmax/dt) ##number of steps
t_list=np.arange(0,Tmax+dt/2,dt)
w = np.random.randn(Nmax+1,Mmax)
u0 = np.array([1., 0., np.sqrt(nth)/2])
u_t=SDE(A,B)
u_mean=np.mean(u_t,axis=2)
```
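In equation form, the update performed by `SDE` is the Euler-Maruyama step below, transcribed directly from `A` and `B`, with u = (x, y, z), n the time step and m the trajectory index; b is just a name for the constant column built in `B`. Each trajectory only sees its own noise sample, which is why the trajectories are independent and can be advanced in parallel:

```latex
u^{(m)}_{n+1} = u^{(m)}_{n} + \Delta t \, M\big(u^{(m)}_{n}\big)\, u^{(m)}_{n} + b \, \sqrt{\Delta t}\, \xi_{n,m},
\qquad \xi_{n,m} \sim \mathcal{N}(0,1) \ \text{i.i.d.}

M(u) = \begin{pmatrix} -2 & \omega & 0 \\ -\omega & 0 & 0 \\ 0 & 0 & -\kappa \end{pmatrix},
\qquad \omega = \tfrac{1}{2}\sqrt{1 + 8\kappa z^{2}},
\qquad b = \sqrt{\kappa \, n_{\mathrm{th}}} \begin{pmatrix} 0 \\ 1 \\ 1 \end{pmatrix}
```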
This code is a simplification of my real code, where the system has a much larger dimension and there are many more trajectories.
This is not efficient, because as I increase Mmax, the Python loop over trajectories gets longer.
Ideally, I would like my solver to look something like this:
```python
def SDE(A,B):
    u = np.zeros((len(u0),Nmax+1,Mmax),dtype=np.complex64)
    u[:,0,:]=u0
    for n in range(0,Nmax):
        u[:,n+1,:] = u[:,n,:]+dt*A(u[:,n,:])+B(u[:,n,:],w[n,:])*np.sqrt(dt)
    return u
```
i.e., to drop the loop over m and advance all trajectories at once. However, naively doing so does not work.
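The naive version fails for three reasons: `u[:,0,:]=u0` cannot broadcast a `(3,)` array into a `(3, Mmax)` slice; `A` builds its 3x3 matrix with `np.array`, which cannot turn a whole vector of `omega` values (one per trajectory) into a numeric 3x3 matrix; and `B` reshapes `B*w`, which now has `3*Mmax` elements, to `len(u0)` elements. Below is a minimal sketch of a vectorized NumPy version, assuming `A` and `B` are rewritten to act on a `(3, Mmax)` slice. The names `A_vec`, `B_vec` and `SDE_vec` are hypothetical, and passing `u0`, `w`, `dt` and `Nmax` as arguments (instead of globals) as well as using `complex128` are choices of this sketch:

```python
import numpy as np

kappa = 0.05
nth = 1.0


def A_vec(u):
    """Drift for all trajectories at once; u has shape (3, Mmax)."""
    x, y, z = u[0], u[1], u[2]
    omega = 0.5 * np.sqrt(1 + 8 * kappa * z * z)  # one omega per trajectory
    # Hand-expanded product of the 3x3 matrix with u, row by row
    return np.stack((-2 * x + omega * y, -omega * x, -kappa * z))


def B_vec(u, w):
    """Diffusion term; w has shape (Mmax,), result has shape (3, Mmax)."""
    g = np.sqrt(kappa * nth)
    b = np.array([0.0, 1.0, 1.0])[:, None] * g  # column of shape (3, 1)
    return b * w[None, :]                        # broadcasts to (3, Mmax)


def SDE_vec(u0, w, dt, Nmax):
    """Euler-Maruyama for all trajectories; w has shape (Nmax+1, Mmax)."""
    Mmax = w.shape[1]
    u = np.zeros((len(u0), Nmax + 1, Mmax), dtype=np.complex128)
    u[:, 0, :] = u0[:, None]  # same initial state for every trajectory
    sqrt_dt = np.sqrt(dt)
    for n in range(Nmax):  # only the time loop remains in Python
        un = u[:, n, :]
        u[:, n + 1, :] = un + dt * A_vec(un) + B_vec(un, w[n, :]) * sqrt_dt
    return u
```

Each time step is then a handful of NumPy operations on arrays of length `Mmax`. For this vectorized form the original `(3, Nmax+1, Mmax)` layout is convenient, because `u[:, n, :]` keeps the trajectory index contiguous.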
Another way to make it more efficient would be to use Numba. However, after many tries, I have not managed to apply njit to my SDE solver.
Here are the performance issues I found in the provided code:
- `A.dot(u)` is clearly overkill here since the `A` matrix is a 3x3 matrix with only 4 non-zero values. It is better to compute the expression manually, especially with Numba.
- The memory layout of `u` is inefficient: the solver works on the 3 components of one trajectory at a time, but with the shape `(len(u0),Nmax+1,Mmax)` these components are far apart in memory. `u` should be transposed; its shape would be `(Mmax,Nmax+1,len(u0))`. For more information, please read about AoS versus SoA.
- `np.reshape(B*w,len(u0))` is confusing: `B` has 3 rows, so `u0` must also be of size 3 and the `reshape` seems useless.
- Note that `np.complex64` is single precision (as stated in the comments); the code below uses `np.complex128`.
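To make the memory-layout point concrete, NumPy's `strides` give the number of bytes to step along each axis. This sketch only inspects zero-filled arrays with the sizes used in the question (`Mmax=100`, `Nmax=200`, `complex128` elements of 16 bytes):

```python
import numpy as np

Mmax, Nmax, ncomp = 100, 200, 3

# Layout used in the question: trajectory index m is the last (contiguous) axis
u_soa = np.zeros((ncomp, Nmax + 1, Mmax), dtype=np.complex128)
# Transposed layout: the 3 components of one (m, n) point are adjacent
u_aos = np.zeros((Mmax, Nmax + 1, ncomp), dtype=np.complex128)

# strides = bytes to move one step along each axis
print(u_soa.strides)  # (321600, 1600, 16)
print(u_aos.strides)  # (9648, 48, 16)
```

With the question's layout, the x, y and z values of one trajectory at one step are 321600 bytes apart, whereas with the transposed layout they are 16 bytes apart. Which layout is better depends on the loop order: the per-trajectory Numba loop favours `(Mmax, Nmax+1, 3)`, while a NumPy version vectorized over trajectories benefits from keeping m as the last axis.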
Here is the resulting Numba code:
```python
import numba as nb # new

@nb.njit('(complex128[:], complex128[::1])')
def A(u, res):
    x=u[0]
    y=u[1]
    z=u[2]
    omega = 0.5 * np.sqrt((1+8*kappa*z*z))
    res[0] = -2 * x + omega * y
    res[1] = -omega * x
    res[2] = -kappa * z
    return res

@nb.njit('(complex128[:], float64, complex128[::1])')
def B(u, w, res):
    g = np.sqrt(kappa*nth)
    res[0] = 0
    res[1] = g * w
    res[2] = g * w
    return res

# No signature is provided so the first call will be much slower
# But providing a signature here is complicated since A and B are functions
@nb.njit
def SDE(A,B):
    u = np.zeros((len(u0),Nmax+1,Mmax), dtype=np.complex128)
    sqrt_dt = np.sqrt(dt)
    for m in range(Mmax):
        u[:,0,m] = u0
        tmp1 = np.empty(3, dtype=np.complex128)
        tmp2 = np.empty(3, dtype=np.complex128)
        for n in range(0,Nmax):
            A(u[:,n,m],tmp1)
            B(u[:,n,m],w[n,m],tmp2)
            for i in range(3):
                u[i,n+1,m] = u[i,n,m] + dt * tmp1[i] + tmp2[i] * sqrt_dt
    return u
```
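Hand-expanding `A.dot(u)` as done in `A` above should not change the result. Here is a quick NumPy-only check (no Numba needed) that compares the two formulations on random complex states; `drift_matrix` and `drift_manual` are hypothetical names for the original and expanded versions:

```python
import numpy as np

kappa = 0.05


def drift_matrix(u):
    """Original formulation: build the full 3x3 matrix and multiply."""
    z = u[2]
    omega = 0.5 * np.sqrt(1 + 8 * kappa * z * z)
    A = np.array([[-2, omega, 0],
                  [-omega, 0, 0],
                  [0, 0, -kappa]])
    return A.dot(u)


def drift_manual(u, res):
    """Expanded formulation: only the 4 non-zero entries are used."""
    x, y, z = u[0], u[1], u[2]
    omega = 0.5 * np.sqrt(1 + 8 * kappa * z * z)
    res[0] = -2 * x + omega * y
    res[1] = -omega * x
    res[2] = -kappa * z
    return res


rng = np.random.default_rng(0)
u = rng.standard_normal(3) + 1j * rng.standard_normal(3)
res = np.empty(3, dtype=np.complex128)
print(np.allclose(drift_manual(u, res), drift_matrix(u)))  # True
```

The expanded version also writes into a preallocated `res`, which is what lets the Numba code avoid allocating a new matrix and result vector at every step.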
This is about 500 times faster on my machine (with an i5-9600KF CPU).
I think there is no need to use multiple threads once the code is optimized, since the computation is already pretty fast. If this is not enough, you can add the flag `parallel=True` and replace `for m in range(Mmax)` with `for m in nb.prange(Mmax)`. However, this will not scale well due to false sharing caused by the bad memory layout. As stated in the list above, you should swap the first and last axes of `u` to fix this issue.
In the end, the final code should look like this once parallelized and with a better memory layout:
```python
# A and B are the same as before
# No signature is provided so the first call will be much slower
# But providing a signature here is complicated since A and B are functions
@nb.njit(parallel=True)
def SDE(A,B):
    u = np.zeros((Mmax,Nmax+1,len(u0)), dtype=np.complex128)
    sqrt_dt = np.sqrt(dt)
    for m in nb.prange(Mmax):
        u[m,0,:] = u0
        tmp1 = np.empty(3, dtype=np.complex128)
        tmp2 = np.empty(3, dtype=np.complex128)
        for n in range(0,Nmax):
            A(u[m,n,:],tmp1)
            B(u[m,n,:],w[m,n],tmp2)
            for i in range(3):
                u[m,n+1,i] = u[m,n,i] + dt * tmp1[i] + tmp2[i] * sqrt_dt
    return u

# [...] same code
w = np.random.randn(Nmax+1,Mmax).T.copy()
u0 = np.array([1., 0., np.sqrt(nth)/2])
u_t=SDE(A,B)
u_mean=np.mean(u_t,axis=0)
```
This code is about 2000 times faster.
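The `.T.copy()` in the driver code gives `w` the same trajectory-major layout as `u`, while drawing the same random numbers as the original `w[n,m]`. `w.T` alone is only a view with swapped strides, so `w[m, n]` for consecutive n would still be `Mmax` elements apart; the copy (or, equivalently, `np.ascontiguousarray(w.T)`) rewrites it in C order. A small check with the question's sizes:

```python
import numpy as np

Nmax, Mmax = 200, 100
w = np.random.randn(Nmax + 1, Mmax)

w_view = w.T        # shape (Mmax, Nmax + 1), but no data is moved
w_c = w.T.copy()    # same values, laid out in C (row-major) order

print(w_view.flags['C_CONTIGUOUS'])  # False
print(w_c.flags['C_CONTIGUOUS'])     # True
print(w_view.strides, w_c.strides)   # (8, 800) (1608, 8)
```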
Note that `w` is a global variable, so Numba treats it as a compile-time constant (i.e. it should never change during the lifetime of the application). If that is not the case, pass it as a parameter.
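A minimal sketch of passing everything explicitly, assuming the 3-component model from the question: `u0`, `w`, `dt`, `kappa` and `nth` become arguments, so new values of the same types reuse the compiled code instead of being frozen at compile time, and `A` and `B` are inlined by hand. The function name `sde_em` is hypothetical, and the `try/except` fallback (plain Python when Numba is not installed) exists only so the sketch runs anywhere; with Numba available it compiles with `parallel=True` and `nb.prange`:

```python
import numpy as np

try:
    import numba as nb
    njit, prange = nb.njit, nb.prange
except ImportError:  # fallback: run as plain (slow) Python without Numba
    def njit(*args, **kwargs):
        if len(args) == 1 and callable(args[0]) and not kwargs:
            return args[0]
        return lambda f: f
    prange = range


@njit(parallel=True)
def sde_em(u0, w, dt, kappa, nth):
    """Euler-Maruyama for all trajectories; w has shape (Mmax, Nmax+1)."""
    Mmax, Nsteps = w.shape
    Nmax = Nsteps - 1
    u = np.zeros((Mmax, Nmax + 1, 3), dtype=np.complex128)
    sqrt_dt = np.sqrt(dt)
    g = np.sqrt(kappa * nth)
    for m in prange(Mmax):  # trajectories are independent
        for i in range(3):
            u[m, 0, i] = u0[i]
        for n in range(Nmax):
            x = u[m, n, 0]
            y = u[m, n, 1]
            z = u[m, n, 2]
            omega = 0.5 * np.sqrt(1 + 8 * kappa * z * z)
            noise = g * w[m, n] * sqrt_dt  # B only acts on y and z
            u[m, n + 1, 0] = x + dt * (-2 * x + omega * y)
            u[m, n + 1, 1] = y + dt * (-omega * x) + noise
            u[m, n + 1, 2] = z + dt * (-kappa * z) + noise
    return u
```

It would be called as `u_t = sde_em(u0, w, dt, kappa, nth)`, with `w` in the transposed `(Mmax, Nmax+1)` layout such as `np.random.randn(Nmax+1, Mmax).T.copy()`.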
By the way, Julia might be a better language for such a computation, since its standard implementation is a JIT compiler and it makes it easy to avoid creating new temporary arrays (though creating new arrays is still quite expensive, even in Julia).