Search code examples
python-3.x pytorch numerical-methods ode

Why can't I get this Runge-Kutta solver to converge as the time step decreases?


For reasons, I need to implement the Runge-Kutta 4 method in PyTorch (so no, I'm not going to use scipy.integrate.odeint). I tried and I get weird results on the simplest test case, solving x'=x with x(0)=1 (analytical solution: x=exp(t)). Basically, as I reduce the time step, I cannot get the numerical error to go down. I'm able to do it with a simpler Euler method, but not with the Runge-Kutta 4 method, which makes me suspect some floating point issue here (maybe I'm missing some hidden conversion from double precision to single)?

import torch
import numpy as np
import matplotlib.pyplot as plt

def Euler(f, IC, time_grid):
    """Integrate ``y' = f(t, y)`` with the explicit (forward) Euler method.

    Args:
        f: Callable ``f(t, y)`` returning the derivative for time ``t`` and
            state ``y``; both arguments are passed as tensors.
        IC: Initial condition, i.e. the state at ``time_grid[0]``.
        time_grid: 1-D tensor of time points; the spacing may be non-uniform.

    Returns:
        Tensor whose entry ``k`` approximates the solution at
        ``time_grid[k]`` (entry 0 is ``IC``). A grid with fewer than two
        points yields only the initial condition.
    """
    # No dtype is requested here, so torch infers it from IC (its default
    # floating dtype when IC is a Python float).
    y0 = torch.tensor([IC])
    # Cast the grid to the state's dtype/device so the whole integration runs
    # in one precision; note this may down-cast a higher-precision grid.
    time_grid = time_grid.to(y0[0])

    # Collect the states in a list and stack once at the end: growing a tensor
    # with torch.cat on every step recopies the whole history each iteration.
    states = [y0[0]]
    for t_i, t_next in zip(time_grid[:-1], time_grid[1:]):
        y_i = states[-1]
        dt = t_next - t_i
        states.append(y_i + f(t_i, y_i) * dt)

    return torch.stack(states, dim=0)

def RungeKutta4(f, IC, time_grid):
    """Integrate ``y' = f(t, y)`` with the classical 4th-order Runge-Kutta method.

    Args:
        f: Callable ``f(t, y)`` returning the derivative for time ``t`` and
            state ``y``; both arguments are passed as tensors.
        IC: Initial condition, i.e. the state at ``time_grid[0]``.
        time_grid: 1-D tensor of time points; the spacing may be non-uniform.

    Returns:
        Tensor whose entry ``k`` approximates the solution at
        ``time_grid[k]`` (entry 0 is ``IC``). A grid with fewer than two
        points yields only the initial condition.
    """
    # No dtype is requested here, so torch infers it from IC (its default
    # floating dtype when IC is a Python float).
    y0 = torch.tensor([IC])
    # Cast the grid to the state's dtype/device so the whole integration runs
    # in one precision; note this may down-cast a higher-precision grid.
    time_grid = time_grid.to(y0[0])

    # Collect the states in a list and stack once at the end: growing a tensor
    # with torch.cat on every step recopies the whole history each iteration.
    states = [y0[0]]
    for t_i, t_next in zip(time_grid[:-1], time_grid[1:]):
        y_i = states[-1]
        dt = t_next - t_i
        dtd2 = 0.5 * dt
        # Slope samples: start of the step, two at the midpoint, end of step.
        f1 = f(t_i, y_i)
        f2 = f(t_i + dtd2, y_i + dtd2 * f1)
        f3 = f(t_i + dtd2, y_i + dtd2 * f2)
        f4 = f(t_next, y_i + dt * f3)
        # Weighted slope average with weights 1/6, 2/6, 2/6, 1/6.
        dy = 1 / 6 * dt * (f1 + 2 * (f2 + f3) + f4)
        states.append(y_i + dy)

    return torch.stack(states, dim=0)

# differential equation
def f(T, X):
    """Right-hand side of the test ODE ``x' = x``.

    The time argument ``T`` is accepted to match the ``f(t, y)`` signature
    the solvers call, but the equation is autonomous, so it is not used.
    """
    derivative = X
    return derivative

# initial condition x(0) = 1 for the test problem
IC = 1.0

# integration interval
def integration_interval(steps, ND=1):
    """Return ``steps`` evenly spaced time points from 0 to ``ND``.

    ``steps`` is the number of grid points handed to ``torch.linspace``,
    so the grid contains ``steps - 1`` sub-intervals.
    """
    grid = torch.linspace(start=0, end=ND, steps=steps)
    return grid

# analytical solution
def analytical_solution(t_range):
    """Evaluate ``exp(t)`` elementwise on ``t_range``.

    This is the exact solution of the test problem ``x' = x``, ``x(0) = 1``.
    """
    exact = np.exp(t_range)
    return exact

# test a numerical method
def test_method(method, t_range, analytical_solution):
    """Return the max-norm error of ``method`` on the grid ``t_range``.

    ``method`` is called as ``method(f, IC, t_range)`` with the module-level
    right-hand side ``f`` and initial condition ``IC``; its output is compared
    against ``analytical_solution``, the reference values on the same grid.
    """
    approximation = method(f, IC, t_range)
    # p = inf selects the infinity norm of the difference of the two tensors.
    return torch.dist(approximation, analytical_solution, p=float('inf'))


if __name__ == '__main__':
    # Convergence study: measure the max-norm error of both solvers on grids
    # of 10, 100 and 1000 points and plot error against grid size.
    import os

    n_steps = np.power(10, np.arange(1, 4))
    Euler_error = np.zeros(n_steps.shape[0])
    RungeKutta4_error = np.zeros(n_steps.shape[0])
    for i, n in enumerate(n_steps):
        # Pass a plain Python int rather than a NumPy integer to torch.
        t_range = integration_interval(steps=int(n))
        solution = analytical_solution(t_range)
        Euler_error[i] = test_method(Euler, t_range, solution).item()
        RungeKutta4_error[i] = test_method(RungeKutta4, t_range, solution).item()

    plots_path = "./plots"
    # savefig does not create missing directories, so make sure it exists.
    os.makedirs(plots_path, exist_ok=True)
    plt.figure()
    plt.xscale('log')
    plt.yscale('log')
    plt.plot(n_steps, Euler_error, label="Euler error", linestyle='-')
    plt.plot(n_steps, RungeKutta4_error, label="RungeKutta 4 error", linestyle='-.')
    plt.legend()
    plt.savefig(os.path.join(plots_path, "errors.png"))

The result:

enter image description here

As you can see, the Euler method converges (slowly, as expected of a first order method). However, the Runge-Kutta4 method does not converge as the time step gets smaller and smaller. The error goes down initially, and then up again. What's the issue here?


Solution

  • The reason is indeed a floating point precision issue. torch defaults to single precision, so once the truncation error becomes small enough, the total error is dominated by the roundoff error, and reducing the truncation error further by increasing the number of steps (equivalently, decreasing the time step) no longer decreases the total error. Because roundoff accumulates with every step, taking more steps can even make the total error grow again, which is the upturn you see in the plot.

    To fix this, we need to enforce double precision (64-bit floats) for all floating point torch tensors and numpy arrays. The code below spells the types as torch.float64 and np.float64 because the bit width is explicit in the name; torch.double is simply an alias of torch.float64, and np.double is the C double type, which is 64 bits wide on the platforms NumPy supports, so either spelling selects the same precision (the NumPy type whose width really does depend on the machine and compiler is np.longdouble). Here's the fixed code:

    import torch
    import numpy as np
    import matplotlib.pyplot as plt
    
    def Euler(f, IC, time_grid):
        """Integrate ``y' = f(t, y)`` with the explicit (forward) Euler method.

        Args:
            f: Callable ``f(t, y)`` returning the derivative for time ``t`` and
                state ``y``; both arguments are passed as tensors.
            IC: Initial condition, i.e. the state at ``time_grid[0]``.
            time_grid: 1-D tensor of time points; the spacing may be non-uniform.

        Returns:
            float64 tensor whose entry ``k`` approximates the solution at
            ``time_grid[k]`` (entry 0 is ``IC``). A grid with fewer than two
            points yields only the initial condition.
        """
        # Double precision keeps roundoff well below the truncation error.
        y0 = torch.tensor([IC], dtype=torch.float64)
        # Cast the grid to the state's dtype/device so the whole integration
        # runs in one precision.
        time_grid = time_grid.to(y0[0])

        # Collect the states in a list and stack once at the end: growing a
        # tensor with torch.cat on every step recopies the whole history.
        states = [y0[0]]
        for t_i, t_next in zip(time_grid[:-1], time_grid[1:]):
            y_i = states[-1]
            dt = t_next - t_i
            states.append(y_i + f(t_i, y_i) * dt)

        return torch.stack(states, dim=0)
    
    def RungeKutta4(f, IC, time_grid):
        """Integrate ``y' = f(t, y)`` with the classical 4th-order Runge-Kutta method.

        Args:
            f: Callable ``f(t, y)`` returning the derivative for time ``t`` and
                state ``y``; both arguments are passed as tensors.
            IC: Initial condition, i.e. the state at ``time_grid[0]``.
            time_grid: 1-D tensor of time points; the spacing may be non-uniform.

        Returns:
            float64 tensor whose entry ``k`` approximates the solution at
            ``time_grid[k]`` (entry 0 is ``IC``). A grid with fewer than two
            points yields only the initial condition.
        """
        # Double precision keeps roundoff well below the truncation error.
        y0 = torch.tensor([IC], dtype=torch.float64)
        # Cast the grid to the state's dtype/device so the whole integration
        # runs in one precision.
        time_grid = time_grid.to(y0[0])

        # Collect the states in a list and stack once at the end: growing a
        # tensor with torch.cat on every step recopies the whole history.
        states = [y0[0]]
        for t_i, t_next in zip(time_grid[:-1], time_grid[1:]):
            y_i = states[-1]
            dt = t_next - t_i
            dtd2 = 0.5 * dt
            # Slope samples: start of the step, two at the midpoint, end of step.
            f1 = f(t_i, y_i)
            f2 = f(t_i + dtd2, y_i + dtd2 * f1)
            f3 = f(t_i + dtd2, y_i + dtd2 * f2)
            f4 = f(t_next, y_i + dt * f3)
            # Weighted slope average with weights 1/6, 2/6, 2/6, 1/6.
            dy = 1 / 6 * dt * (f1 + 2 * (f2 + f3) + f4)
            states.append(y_i + dy)

        return torch.stack(states, dim=0)
    
    # differential equation
    def f(T, X):
        """Right-hand side of the test ODE ``x' = x``.

        The time argument ``T`` is accepted to match the ``f(t, y)`` signature
        the solvers call, but the equation is autonomous, so it is not used.
        """
        derivative = X
        return derivative
    
    # initial condition x(0) = 1 for the test problem
    IC = 1.0
    
    # integration interval
    def integration_interval(steps, ND=1):
        """Return ``steps`` evenly spaced float64 time points from 0 to ``ND``.

        ``steps`` is the number of grid points handed to ``torch.linspace``,
        so the grid contains ``steps - 1`` sub-intervals.
        """
        grid = torch.linspace(start=0, end=ND, steps=steps, dtype=torch.float64)
        return grid
    
    # analytical solution
    def analytical_solution(t_range):
        """Evaluate ``exp(t)`` elementwise on ``t_range`` in float64.

        This is the exact solution of the test problem ``x' = x``, ``x(0) = 1``.
        """
        exact = np.exp(t_range, dtype=np.float64)
        return exact
    
    # test a numerical method
    def test_method(method, t_range, analytical_solution):
        """Return the max-norm error of ``method`` on the grid ``t_range``.

        ``method`` is called as ``method(f, IC, t_range)`` with the module-level
        right-hand side ``f`` and initial condition ``IC``; its output is compared
        against ``analytical_solution``, the reference values on the same grid.
        """
        approximation = method(f, IC, t_range)
        # p = inf selects the infinity norm of the difference of the two tensors.
        return torch.dist(approximation, analytical_solution, p=float('inf'))
    
    
    if __name__ == '__main__':
        # Convergence study: measure the max-norm error of both solvers on grids
        # of 10, 100 and 1000 points and plot error against grid size.
        import os

        n_steps = np.power(10, np.arange(1, 4))
        Euler_error = np.zeros(n_steps.shape[0], dtype=np.float64)
        RungeKutta4_error = np.zeros(n_steps.shape[0], dtype=np.float64)
        for i, n in enumerate(n_steps):
            # Pass a plain Python int rather than a NumPy integer to torch.
            t_range = integration_interval(steps=int(n))
            solution = analytical_solution(t_range)
            Euler_error[i] = test_method(Euler, t_range, solution).item()
            RungeKutta4_error[i] = test_method(RungeKutta4, t_range, solution).item()

        plots_path = "./plots"
        # savefig does not create missing directories, so make sure it exists.
        os.makedirs(plots_path, exist_ok=True)
        plt.figure()
        plt.xscale('log')
        plt.yscale('log')
        plt.plot(n_steps, Euler_error, label="Euler error", linestyle='-')
        plt.plot(n_steps, RungeKutta4_error, label="RungeKutta 4 error", linestyle='-.')
        plt.legend()
        plt.savefig(os.path.join(plots_path, "errors.png"))
    

    Result:

    enter image description here

    Now, as we decrease the time step, the error of the RungeKutta4 approximation decreases with the correct rate.