For reasons, I need to implement the fourth-order Runge–Kutta (RK4) method in PyTorch (so no, I'm not going to use `scipy.integrate.odeint`). I tried, and I get weird results on the simplest test case: solving x' = x with x(0) = 1 (analytical solution: x = exp(t)). Basically, as I reduce the time step, I cannot get the numerical error to go down. I'm able to do it with the simpler Euler method, but not with RK4, which makes me suspect some floating-point issue (maybe I'm missing some hidden conversion from double precision to single)?
import torch
import numpy as np
import matplotlib.pyplot as plt
def Euler(f, IC, time_grid, dtype=None):
    """Integrate y' = f(t, y) on ``time_grid`` with the explicit Euler method.

    Parameters
    ----------
    f : callable
        Right-hand side, called as ``f(t_i, y_i)`` with tensor arguments.
    IC : scalar (or anything ``torch.tensor`` accepts inside a list)
        Initial value at ``time_grid[0]``.
    time_grid : torch.Tensor
        1-D tensor of time points. It is cast to the dtype/device of the state,
        so the time points are held in the same precision as the solution.
    dtype : torch.dtype, optional
        Dtype of the solution. ``None`` (the default) keeps the dtype that
        ``torch.tensor`` infers for ``IC``. For a Python float, that is torch's
        default dtype (float32 unless changed). A non-floating inferred dtype
        (e.g. an integer ``IC``) is promoted to torch's default floating dtype.

    Returns
    -------
    torch.Tensor
        ``values[i]`` approximates y(time_grid[i]). The result always contains
        at least the initial value.
    """
    y0 = torch.tensor([IC], dtype=dtype)
    if not (y0.is_floating_point() or y0.is_complex()):
        # An integer state would make time_grid.to(...) below round every
        # time point to an integer.
        y0 = y0.to(torch.get_default_dtype())
    time_grid = time_grid.to(y0[0])
    # Collect states in a list and stack once at the end. Concatenating onto
    # the result every step copies it each time (quadratic cost).
    states = [y0[0]]
    for t_i, t_next in zip(time_grid[:-1], time_grid[1:]):
        y_i = states[-1]
        dt = t_next - t_i
        states.append(y_i + f(t_i, y_i) * dt)
    return torch.stack(states)
def RungeKutta4(f, IC, time_grid, dtype=None):
    """Integrate y' = f(t, y) on ``time_grid`` with the classical RK4 scheme.

    Each step evaluates ``f`` at the start, twice at the midpoint, and at the
    end of the interval. It combines the slopes with weights 1/6, 2/6, 2/6, 1/6.

    Parameters
    ----------
    f : callable
        Right-hand side, called as ``f(t, y)`` with tensor arguments.
    IC : scalar (or anything ``torch.tensor`` accepts inside a list)
        Initial value at ``time_grid[0]``.
    time_grid : torch.Tensor
        1-D tensor of time points. It is cast to the dtype/device of the state,
        so the time points are held in the same precision as the solution.
    dtype : torch.dtype, optional
        Dtype of the solution. ``None`` (the default) keeps the dtype that
        ``torch.tensor`` infers for ``IC``. For a Python float, that is torch's
        default dtype (float32 unless changed). A non-floating inferred dtype
        (e.g. an integer ``IC``) is promoted to torch's default floating dtype.

    Returns
    -------
    torch.Tensor
        ``values[i]`` approximates y(time_grid[i]). The result always contains
        at least the initial value.
    """
    y0 = torch.tensor([IC], dtype=dtype)
    if not (y0.is_floating_point() or y0.is_complex()):
        # An integer state would make time_grid.to(...) below round every
        # time point to an integer.
        y0 = y0.to(torch.get_default_dtype())
    time_grid = time_grid.to(y0[0])
    # Collect states in a list and stack once at the end. Concatenating onto
    # the result every step copies it each time (quadratic cost).
    states = [y0[0]]
    for t_i, t_next in zip(time_grid[:-1], time_grid[1:]):
        y_i = states[-1]
        dt = t_next - t_i
        dtd2 = 0.5 * dt
        f1 = f(t_i, y_i)
        f2 = f(t_i + dtd2, y_i + dtd2 * f1)
        f3 = f(t_i + dtd2, y_i + dtd2 * f2)
        f4 = f(t_next, y_i + dt * f3)
        states.append(y_i + 1/6 * dt * (f1 + 2 * (f2 + f3) + f4))
    return torch.stack(states)
# differential equation: x' = x
def f(T, X):
    """Right-hand side of the test problem x' = x.

    ``T`` is accepted to match the solvers' ``f(t, y)`` calling convention
    but is unused: the equation does not depend on time explicitly.
    """
    derivative = X
    return derivative
# initial condition x(0), as a Python float
IC = 1.0
# integration interval
def integration_interval(steps, ND=1):
    """Return ``steps`` evenly spaced time points covering ``[0, ND]``.

    No dtype is requested, so the points use torch's default floating dtype
    (float32 unless changed with ``torch.set_default_dtype``).
    """
    start, end = 0, ND
    return torch.linspace(start, end, steps)
# analytical solution
def analytical_solution(t_range):
    """Return exp(t_range), the exact solution of x' = x with x(0) = 1.

    NOTE(review): the caller passes a torch tensor and later hands the result
    to ``torch.dist``. This relies on NumPy returning a tensor from ``np.exp``
    on tensor input. Confirm this for the NumPy/PyTorch versions in use.
    """
    exact = np.exp(t_range)
    return exact
# test a numerical method
def test_method(method, t_range, analytical_solution):
    """Return the max-norm error of ``method`` on the module's test problem.

    ``method`` is called as ``method(f, IC, t_range)`` with the module-level
    right-hand side ``f`` and initial condition ``IC``. Its output is compared
    with ``analytical_solution``. Here that parameter holds the exact values
    on ``t_range`` (it shadows the function of the same name). The comparison
    uses the L-infinity distance.
    """
    approx = method(f, IC, t_range)
    return torch.dist(approx, analytical_solution, p=float('inf'))
if __name__ == '__main__':
    import os

    # Grid sizes 10, 100, 1000, shown on a log-log error plot.
    n_steps = np.power(10, np.arange(1, 4))
    Euler_error = np.zeros(len(n_steps))
    RungeKutta4_error = np.zeros(len(n_steps))
    for i, n in enumerate(n_steps):
        # linspace counts points, so n points give n - 1 integration steps.
        t_range = integration_interval(steps=int(n))
        solution = analytical_solution(t_range)
        Euler_error[i] = test_method(Euler, t_range, solution).item()
        RungeKutta4_error[i] = test_method(RungeKutta4, t_range, solution).item()
    plots_path = "./plots"
    # savefig does not create missing directories.
    os.makedirs(plots_path, exist_ok=True)
    plt.figure()
    plt.xscale('log')
    plt.yscale('log')
    plt.plot(n_steps, Euler_error, label="Euler error", linestyle='-')
    plt.plot(n_steps, RungeKutta4_error, label="RungeKutta 4 error", linestyle='-.')
    plt.xlabel("number of grid points")
    plt.ylabel("max abs error")
    plt.legend()
    plt.savefig(os.path.join(plots_path, "errors.png"))
The result:
As you can see, the Euler method converges (slowly, as expected of a first-order method). However, the RK4 method does not converge as the time step gets smaller and smaller: the error goes down initially, and then up again. What's the issue here?
The reason is indeed a floating-point precision issue. `torch` defaults to single precision (`float32`), so once the truncation error becomes small enough, the total error is dominated by round-off error. From then on, reducing the truncation error further by increasing the number of steps (i.e., decreasing the time step) doesn't decrease the total error. Each extra step adds its own rounding error, so the accumulated round-off actually grows with the number of steps, which is why the RK4 error goes back up.
To fix this, we need to enforce double precision (64-bit floats) for all floating-point `torch` tensors and `numpy` arrays. Below I spell this as `torch.float64` and `np.float64`, because these names state the width explicitly. `torch.double` and `np.double` are aliases for exactly the same types: `torch.double is torch.float64` and `np.double is np.float64` both evaluate to `True`. Either spelling therefore gives 64-bit floats; it is types such as `np.longdouble` whose size varies with the platform/compiler. Here's the fixed code:
import torch
import numpy as np
import matplotlib.pyplot as plt
def Euler(f, IC, time_grid, dtype=torch.float64):
    """Integrate y' = f(t, y) on ``time_grid`` with the explicit Euler method.

    Parameters
    ----------
    f : callable
        Right-hand side, called as ``f(t_i, y_i)`` with tensor arguments.
    IC : scalar (or anything ``torch.tensor`` accepts inside a list)
        Initial value at ``time_grid[0]``.
    time_grid : torch.Tensor
        1-D tensor of time points. It is cast to the dtype/device of the state,
        so the time points are held in the same precision as the solution.
    dtype : torch.dtype, optional
        Dtype of the solution. The default is float64, which is the precision
        fix for this script. ``None`` keeps the dtype ``torch.tensor`` infers
        for ``IC``. A non-floating inferred dtype is promoted to torch's
        default floating dtype.

    Returns
    -------
    torch.Tensor
        ``values[i]`` approximates y(time_grid[i]). The result always contains
        at least the initial value.
    """
    y0 = torch.tensor([IC], dtype=dtype)
    if not (y0.is_floating_point() or y0.is_complex()):
        # An integer state would make time_grid.to(...) below round every
        # time point to an integer.
        y0 = y0.to(torch.get_default_dtype())
    time_grid = time_grid.to(y0[0])
    # Collect states in a list and stack once at the end. Concatenating onto
    # the result every step copies it each time (quadratic cost).
    states = [y0[0]]
    for t_i, t_next in zip(time_grid[:-1], time_grid[1:]):
        y_i = states[-1]
        dt = t_next - t_i
        states.append(y_i + f(t_i, y_i) * dt)
    return torch.stack(states)
def RungeKutta4(f, IC, time_grid, dtype=torch.float64):
    """Integrate y' = f(t, y) on ``time_grid`` with the classical RK4 scheme.

    Each step evaluates ``f`` at the start, twice at the midpoint, and at the
    end of the interval. It combines the slopes with weights 1/6, 2/6, 2/6, 1/6.

    Parameters
    ----------
    f : callable
        Right-hand side, called as ``f(t, y)`` with tensor arguments.
    IC : scalar (or anything ``torch.tensor`` accepts inside a list)
        Initial value at ``time_grid[0]``.
    time_grid : torch.Tensor
        1-D tensor of time points. It is cast to the dtype/device of the state,
        so the time points are held in the same precision as the solution.
    dtype : torch.dtype, optional
        Dtype of the solution. The default is float64, which is the precision
        fix for this script. ``None`` keeps the dtype ``torch.tensor`` infers
        for ``IC``. A non-floating inferred dtype is promoted to torch's
        default floating dtype.

    Returns
    -------
    torch.Tensor
        ``values[i]`` approximates y(time_grid[i]). The result always contains
        at least the initial value.
    """
    y0 = torch.tensor([IC], dtype=dtype)
    if not (y0.is_floating_point() or y0.is_complex()):
        # An integer state would make time_grid.to(...) below round every
        # time point to an integer.
        y0 = y0.to(torch.get_default_dtype())
    time_grid = time_grid.to(y0[0])
    # Collect states in a list and stack once at the end. Concatenating onto
    # the result every step copies it each time (quadratic cost).
    states = [y0[0]]
    for t_i, t_next in zip(time_grid[:-1], time_grid[1:]):
        y_i = states[-1]
        dt = t_next - t_i
        dtd2 = 0.5 * dt
        f1 = f(t_i, y_i)
        f2 = f(t_i + dtd2, y_i + dtd2 * f1)
        f3 = f(t_i + dtd2, y_i + dtd2 * f2)
        f4 = f(t_next, y_i + dt * f3)
        states.append(y_i + 1/6 * dt * (f1 + 2 * (f2 + f3) + f4))
    return torch.stack(states)
# differential equation: x' = x
def f(T, X):
    """Right-hand side of the test problem x' = x.

    ``T`` is accepted to match the solvers' ``f(t, y)`` calling convention
    but is unused: the equation does not depend on time explicitly.
    """
    derivative = X
    return derivative
# initial condition x(0), as a Python float
IC = 1.0
# integration interval
def integration_interval(steps, ND=1):
    """Return ``steps`` evenly spaced float64 time points covering ``[0, ND]``."""
    start, end = 0, ND
    return torch.linspace(start, end, steps, dtype=torch.float64)
# analytical solution
def analytical_solution(t_range):
    """Return exp(t_range), computed as float64.

    This is the exact solution of x' = x with x(0) = 1.

    NOTE(review): the caller passes a torch tensor and later hands the result
    to ``torch.dist``. This relies on NumPy returning a tensor from ``np.exp``
    on tensor input. Confirm this for the NumPy/PyTorch versions in use.
    """
    precision = np.float64
    return np.exp(t_range, dtype=precision)
# test a numerical method
def test_method(method, t_range, analytical_solution):
    """Return the max-norm error of ``method`` on the module's test problem.

    ``method`` is called as ``method(f, IC, t_range)`` with the module-level
    right-hand side ``f`` and initial condition ``IC``. Its output is compared
    with ``analytical_solution``. Here that parameter holds the exact values
    on ``t_range`` (it shadows the function of the same name). The comparison
    uses the L-infinity distance.
    """
    approx = method(f, IC, t_range)
    return torch.dist(approx, analytical_solution, p=float('inf'))
if __name__ == '__main__':
    import os

    # Grid sizes 10, 100, 1000, shown on a log-log error plot.
    n_steps = np.power(10, np.arange(1, 4))
    Euler_error = np.zeros(len(n_steps), dtype=np.float64)
    RungeKutta4_error = np.zeros(len(n_steps), dtype=np.float64)
    for i, n in enumerate(n_steps):
        # linspace counts points, so n points give n - 1 integration steps.
        t_range = integration_interval(steps=int(n))
        solution = analytical_solution(t_range)
        Euler_error[i] = test_method(Euler, t_range, solution).item()
        RungeKutta4_error[i] = test_method(RungeKutta4, t_range, solution).item()
    plots_path = "./plots"
    # savefig does not create missing directories.
    os.makedirs(plots_path, exist_ok=True)
    plt.figure()
    plt.xscale('log')
    plt.yscale('log')
    plt.plot(n_steps, Euler_error, label="Euler error", linestyle='-')
    plt.plot(n_steps, RungeKutta4_error, label="RungeKutta 4 error", linestyle='-.')
    plt.xlabel("number of grid points")
    plt.ylabel("max abs error")
    plt.legend()
    plt.savefig(os.path.join(plots_path, "errors.png"))
Result:
Now, as we decrease the time step, the error of the RungeKutta4 approximation decreases with the correct rate.