Search code examples
Tags: python, ode, pde, computation, finite-difference

How to write a solver for the vibration ODE using the current version of Devito?


I'm reading `devito_book/fdm-jupyter-book/notebooks/01_vib/vib_undamped.ipynb`, and the code in it does not seem to be compatible with Devito 4.8.3.

So I tried to rewrite it as:

import numpy as np
from devito import Constant, TimeFunction, Eq, solve, Operator, Grid, TimeDimension


def solver(I, w, dt, T):
    """Attempt to solve u'' + w**2 * u = 0 with u(0) = I using Devito.

    Returns ``(u.data, t_mesh)``.

    NOTE(review): this is the version the question reports as broken. The
    returned array is not a time series of the solution (see the comments
    below), which is why its values appear not to change.
    """
    dt = float(dt)
    Nt = int(round(T / dt))  # number of time steps
    # Time dimension whose step is the symbolic constant h_t; its value is
    # bound to dt in op.apply(h_t=dt) below.
    t = TimeDimension('t', spacing=Constant('h_t'))
    # NOTE(review): shape=(Nt + 1,) creates a *spatial* axis of Nt + 1
    # points; it does not index time levels.
    grid = Grid(shape=(Nt + 1, ), time_dimension=t)
    # NOTE(review): no save=... is given, so Devito allocates only a rolling
    # buffer of time_order + 1 = 3 time levels (u.data has shape (3, Nt + 1))
    # and older levels are overwritten as the operator advances in time.
    # Passing save=Nt + 1 would keep every level.
    u = TimeFunction(name='u', grid=grid, time_order=2)

    # Sets every buffered time level at every grid point to I; the second
    # level u^1 never receives its own first-step value. Because the equation
    # has no spatial coupling, all points along the spatial axis stay equal.
    u.data[:] = I
    eqn = u.dt2 + w ** 2 * u  # residual of u'' + w^2 u = 0
    # Solve the discretized residual == 0 for u at the next time level.
    stencil = Eq(u.forward, solve(eqn, u.forward))
    op = Operator(stencil)
    op.apply(h_t=dt, t_M=Nt - 1)
    # Returns the raw (buffered) data array, not a 1-D solution history.
    return u.data, np.linspace(0, Nt * dt, Nt + 1)


# Parameters: unit amplitude, w = 2*pi (one oscillation per unit time).
I, w = 1, 2 * np.pi
dt, num_periods = 0.05, 5
P = 2 * np.pi / w  # period of the oscillation
T = num_periods * P  # simulate num_periods full periods
u, t = solver(I, w, dt, T)

However, the returned `u` data look like this (the printed array was not preserved in this copy):



The values do not change from one time step to the next. Could anyone help me solve this issue?

Thank you in advance!


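Solution

In the question's code, the `TimeFunction` is created without `save`, so Devito keeps only a rolling buffer of `time_order + 1 = 3` time levels. The grid axis of length `Nt + 1` is a spatial dimension, not time. Every grid point starts at `I`, and the ODE has no spatial coupling, so all values along that axis are identical, which looks like "no update in time". The fix below makes three changes:

- It stores the full time history with `save=Nt + 1`.
- It uses a small dummy spatial grid.
- It sets the second time level explicitly with the first-step formula $u^1 = u^0 - \tfrac{1}{2}\,\Delta t^2\,\omega^2\,u^0$ (valid for $u'(0) = 0$).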

  • import numpy as np
    import matplotlib.pyplot as plt
    from devito import Constant, TimeFunction, Eq, solve, Operator, Grid, TimeDimension
    
    
    def solver(I, w, dt, T):
        """Solve u'' + w**2 * u = 0, u(0) = I, u'(0) = 0 with Devito.

        Uses the centered three-level scheme
        u^{n+1} = 2 u^n - u^{n-1} - dt^2 w^2 u^n.

        Args:
            I: initial displacement u(0).
            w: angular frequency.
            dt: time step; must be positive.
            T: final time; the mesh uses Nt = round(T / dt) steps.

        Returns:
            (u, t): the numerical solution at the Nt + 1 mesh points, and the
            matching uniform time mesh from 0 to Nt * dt.

        Raises:
            ValueError: if dt is not positive or T / dt yields fewer than
                two time steps.
        """
        dt = float(dt)
        if dt <= 0:
            raise ValueError('dt must be positive, got %g' % dt)
        Nt = int(round(T / dt))
        if Nt < 2:
            # The scheme needs u^0, u^1 and at least one operator step.
            raise ValueError(
                'T / dt must give at least 2 time steps, got Nt=%d' % Nt)

        t = TimeDimension('t', spacing=Constant('h_t'))
        # The ODE has no spatial coordinate, but Devito needs a grid: use a
        # tiny dummy axis. Every column ends up holding the same solution.
        grid = Grid(shape=(2, ), time_dimension=t)
        # save=Nt + 1 stores every time level instead of a rolling buffer of
        # time_order + 1 levels, so the whole history can be returned.
        u = TimeFunction(name='u', grid=grid, time_order=2, save=Nt + 1)

        # Level 0 is the initial condition. Level 1 comes from the special
        # first step u^1 = u^0 - 0.5 dt^2 w^2 u^0 (using u'(0) = 0), which
        # the three-level scheme cannot produce on its own.
        u.data[0] = I
        u.data[1] = (1 - 0.5 * dt ** 2 * w ** 2) * I

        eqn = u.dt2 + w ** 2 * u  # residual of u'' + w^2 u = 0
        stencil = Eq(u.forward, solve(eqn, u.forward))
        op = Operator(stencil)
        # t_M = Nt - 1 makes the last update write u.forward at level Nt.
        op.apply(h_t=dt, t_M=Nt - 1)
        return np.array(u.data)[:, 0], np.linspace(0, Nt * dt, Nt + 1)
    
    
    def u_exact(t, I, w):
        """Return the exact solution I*cos(w*t) (u(0) = I, u'(0) = 0)."""
        phase = w * t
        return I * np.cos(phase)
    
    
    def visualize(u, t, I, w):
        """Plot the numerical solution against the exact one and save it.

        Args:
            u: numerical solution values at the mesh points ``t``.
            t: uniform time mesh with at least two points.
            I, w: parameters forwarded to ``u_exact`` for the exact curve.

        Writes the current figure to ``tmp.png`` and ``tmp.pdf``.
        """
        plt.plot(t, u, 'r--o')
        t_fine = np.linspace(0, t[-1], 1001)
        u_e = u_exact(t_fine, I, w)
        plt.plot(t_fine, u_e, 'b-')
        plt.legend(['numerical', 'exact'], loc='upper left')
        plt.xlabel('t')
        plt.ylabel('u')
        dt = t[1] - t[0]
        plt.title('dt=%g' % dt)
        # Symmetric y-limits with 20% headroom, based on the largest |u|.
        # Deriving them from min(u) alone inverts or collapses the axis when
        # u never goes negative; fall back to 1.0 if u is identically zero.
        umax = 1.2 * float(np.abs(u).max()) or 1.0
        umin = -umax
        plt.axis((t[0], t[-1], umin, umax))
        plt.savefig('tmp.png')
        plt.savefig('tmp.pdf')
    
    
    # Undamped oscillator: unit amplitude, w = 2*pi, five periods.
    I, w = 1, 2 * np.pi
    dt, num_periods = 0.05, 5
    P = 2 * np.pi / w  # period of the oscillation
    T = num_periods * P
    u, t = solver(I, w, dt, T)
    visualize(u, t, I, w)
    

    Result