
Solving systems of ODEs in Python with varying parameters


I just started using Python, and I am trying to transfer my Matlab code to Python.

I want to run my model several times, once for each set of parameter values:

P = [0, 0.5, 0.7];
d = [0.001, 0.002, 0.003];
B = [0.0095, 0.0080, 0.0070];
G = [0.001, 0.002, 0.003];
A = [0.001, 0.002, 0.003];
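In Python, the same three parameter sets can be kept as plain lists and walked through together with the built-in `zip`, which plays the role of Matlab's `P(i), d(i), ...` indexing. A minimal sketch; the name `param_sets` is illustrative only:

```python
# Parameter values from the question, one list per parameter
P = [0, 0.5, 0.7]
d = [0.001, 0.002, 0.003]
B = [0.0095, 0.0080, 0.0070]
G = [0.001, 0.002, 0.003]
A = [0.001, 0.002, 0.003]

# zip pairs up the i-th entries, giving one (P, d, B, G, A) tuple per run
param_sets = list(zip(P, d, B, G, A))

for i, (p, dd, b, g, a) in enumerate(param_sets):
    print(f"run {i}: P={p}, d={dd}, B={b}, G={g}, A={a}")
```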

In Matlab, I can easily run the model three times, once with each parameter set, as in the sample code below. (I based the model on the Zombie_Apocalypse_ODEINT example from the SciPy Cookbook: https://scipy-cookbook.readthedocs.io/items/Zombie_Apocalypse_ODEINT.html)

S0 = 500;   % initial population
Z0 = 0;     % initial zombie population
R0 = 0;     % initial death population

tspan = [0, 5];
y0 = [S0, Z0, R0];


P = [0, 0.5, 0.7];
d = [0.001, 0.002, 0.003];
B = [0.0095, 0.0080, 0.0070];
G = [0.001, 0.002, 0.003];
A = [0.001, 0.002, 0.003];

% Store each run's output instead of overwriting it on every iteration
time = cell(1, 3);
y = cell(1, 3);
for i = 1:3
    [time{i}, y{i}] = System_dynamics(tspan, y0, P(i), d(i), B(i), G(i), A(i));
end

function [time, dxdt] = System_dynamics(tspan,y0,P, d, B,G, A)

[time, dxdt] = ode23(@solve_ode,tspan,y0);

 function dxdt = solve_ode(t,y)

     Si = y(1);
     Zi = y(2);
     Ri = y(3); 

     f0 = P - B*Si*Zi - d*Si ;
     f1 = B*Si*Zi + G*Ri - A*Si*Zi;
     f2 = d*Si + A*Si*Zi - G*Ri;

     dxdt = [f0, f1, f2]';

 end


end
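For reference, writing $S$, $Z$, $R$ for Si, Zi, Ri, the equations coded in solve_ode above are the system below. Summing the three right-hand sides shows that every interaction term cancels, so the total population changes only through $P$; this gives a quick sanity check on any port of the model:

$$
\begin{aligned}
\frac{dS}{dt} &= P - B S Z - d S \\
\frac{dZ}{dt} &= B S Z + G R - A S Z \\
\frac{dR}{dt} &= d S + A S Z - G R \\
\frac{d}{dt}\,(S + Z + R) &= P \quad\Longrightarrow\quad S(t) + Z(t) + R(t) = S_0 + Z_0 + R_0 + P\,t
\end{aligned}
$$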
 

In Python, I have not found any documentation showing how to do the same thing. In Matlab, the nested function solve_ode can see the parameters passed to the outer function System_dynamics, so I can simply call the outer function in a for loop, as the code above demonstrates.
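Python supports the same pattern: a function defined inside another function can read the outer function's arguments (a closure). The sketch below mirrors `System_dynamics`, assuming SciPy's `solve_ivp` with `method='RK23'` as a stand-in for `ode23` (both use a Bogacki–Shampine 3(2) pair, although step-size control details may differ, so results will not match Matlab digit for digit). The names `system_dynamics` and `rhs` are illustrative, not part of any library:

```python
import numpy as np
from scipy.integrate import solve_ivp


def system_dynamics(tspan, y0, P, d, B, G, A):
    """Solve the model for one parameter set (mirrors the Matlab System_dynamics)."""

    def rhs(t, y):
        # Nested function: P, d, B, G, A come from the enclosing scope,
        # just like the Matlab nested function solve_ode.
        Si, Zi, Ri = y
        f0 = P - B*Si*Zi - d*Si
        f1 = B*Si*Zi + G*Ri - A*Si*Zi
        f2 = d*Si + A*Si*Zi - G*Ri
        return [f0, f1, f2]

    sol = solve_ivp(rhs, tspan, y0, method='RK23')
    # Transpose so rows are time points, matching Matlab's ode23 output layout
    return sol.t, sol.y.T


tspan = (0, 5)
y0 = np.array([500.0, 0.0, 0.0])   # S0, Z0, R0

P = [0, 0.5, 0.7]
d = [0.001, 0.002, 0.003]
B = [0.0095, 0.0080, 0.0070]
G = [0.001, 0.002, 0.003]
A = [0.001, 0.002, 0.003]

# Keep every run's output, like a cell array in Matlab
results = []
for i in range(len(P)):
    time, y = system_dynamics(tspan, y0, P[i], d[i], B[i], G[i], A[i])
    results.append((time, y))
```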

I would like to modify the Python code below so that it achieves the same result as my Matlab code:

import numpy as np
import matplotlib.pyplot as plt
from scipy.integrate import odeint
plt.ion()
plt.rcParams['figure.figsize'] = 10, 8

P = 0      # birth rate
d = 0.0001  # natural death percent (per day)
B = 0.0095  # transmission percent  (per day)
G = 0.0001  # resurrect percent (per day)
A = 0.0001  # destroy percent  (per day)

# solve the system dy/dt = f(y, t)
def f(y, t):
     Si = y[0]
     Zi = y[1]
     Ri = y[2]
     # the model equations (see Munz et al. 2009)
     f0 = P - B*Si*Zi - d*Si
     f1 = B*Si*Zi + G*Ri - A*Si*Zi
     f2 = d*Si + A*Si*Zi - G*Ri
     return [f0, f1, f2]

# initial conditions
S0 = 500.              # initial population
Z0 = 0                 # initial zombie population
R0 = 0                 # initial death population
y0 = [S0, Z0, R0]     # initial condition vector
t  = np.linspace(0, 5., 1000)         # time grid

# solve the DEs
soln = odeint(f, y0, t)
S = soln[:, 0]
Z = soln[:, 1]
R = soln[:, 2]
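One way to reuse this code for several parameter sets is to give the right-hand side explicit parameter arguments and bind them with `functools.partial` from the standard library, so that `odeint` still receives a function of `(y, t)` only. A sketch under that assumption; the name `model` is illustrative:

```python
from functools import partial

import numpy as np
from scipy.integrate import odeint


def model(y, t, P, d, B, G, A):
    # Same equations as f(y, t) above, with the parameters passed explicitly
    Si, Zi, Ri = y
    f0 = P - B*Si*Zi - d*Si
    f1 = B*Si*Zi + G*Ri - A*Si*Zi
    f2 = d*Si + A*Si*Zi - G*Ri
    return [f0, f1, f2]


y0 = [500.0, 0.0, 0.0]
t = np.linspace(0, 5.0, 1000)

P = [0, 0.5, 0.7]
d = [0.001, 0.002, 0.003]
B = [0.0095, 0.0080, 0.0070]
G = [0.001, 0.002, 0.003]
A = [0.001, 0.002, 0.003]

solutions = []
for p, dd, b, g, a in zip(P, d, B, G, A):
    # partial fixes the five parameters, leaving a function of (y, t) only
    rhs = partial(model, P=p, d=dd, B=b, G=g, A=a)
    solutions.append(odeint(rhs, y0, t))

# Sanity check data: the interaction terms cancel when the three equations
# are summed, so S + Z + R should equal 500 + P*t for each run
totals = [soln.sum(axis=1) for soln in solutions]
```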

Solution

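  • This code generates three figures, one per parameter set. Each parameter set is passed to odeint through its args argument, which forwards the values as extra arguments to f(y, t, P, d, B, G, A):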

    import numpy as np
    import matplotlib.pyplot as plt
    from scipy.integrate import odeint
    
    
    # solve the system dy/dt = f(y, t)
    def f(y, t, P, d, B, G, A):
        Si = y[0]
        Zi = y[1]
        Ri = y[2]
        # the model equations (see Munz et al. 2009)
        f0 = P - B*Si*Zi - d*Si
        f1 = B*Si*Zi + G*Ri - A*Si*Zi
        f2 = d*Si + A*Si*Zi - G*Ri
        return np.array([f0, f1, f2])
    
    # initial conditions
    S0 = 500.              # initial population
    Z0 = 0                 # initial zombie population
    R0 = 0                 # initial death population
    y0 = [S0, Z0, R0]     # initial condition vector
    t  = np.linspace(0, 5., 1000)         # time grid
    
    P = [0, 0.5, 0.7]
    d = [0.001, 0.002, 0.003]
    B = [0.0095, 0.0080, 0.0070]
    G = [0.001, 0.002, 0.003]
    A = [0.001, 0.002, 0.003]
    
    for k in range(len(P)):
        # solve the DEs
        args = (P[k], d[k], B[k], G[k], A[k])
        soln = odeint(f, y0, t, args=args)
        S = soln[:, 0]
        Z = soln[:, 1]
        R = soln[:, 2]
        
        suptitle = f'P = {P[k]:.1f}; d = {d[k]:.3f}; B = {B[k]:.4f}; ' + \
                   f'G = {G[k]:.3f}; A = {A[k]:.3f}'
        
        fig = plt.figure(f'k = {k}')
        fig.suptitle(suptitle)
        plt.plot(t, S, label='S')
        plt.plot(t, Z, label='Z')
        plt.legend()
    
    plt.show()
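The loop above overwrites `soln`, `S`, `Z` and `R` on each pass, so only the last run's arrays remain after the loop (the figures still show all three). If you need the data afterwards, one option is to collect the runs in a dict. The sketch below also uses the newer `scipy.integrate.solve_ivp` interface, which (in SciPy 1.4 and later, as far as I know) accepts parameters through its own `args` argument but expects the right-hand side as `fun(t, y, ...)`, with `t` first, the opposite order from `odeint`. The names `rhs` and `runs` are illustrative:

```python
import numpy as np
from scipy.integrate import solve_ivp


def rhs(t, y, P, d, B, G, A):
    # solve_ivp passes t first, then y, then the entries of args
    Si, Zi, Ri = y
    f0 = P - B*Si*Zi - d*Si
    f1 = B*Si*Zi + G*Ri - A*Si*Zi
    f2 = d*Si + A*Si*Zi - G*Ri
    return [f0, f1, f2]


y0 = [500.0, 0.0, 0.0]
t = np.linspace(0, 5.0, 1000)

P = [0, 0.5, 0.7]
d = [0.001, 0.002, 0.003]
B = [0.0095, 0.0080, 0.0070]
G = [0.001, 0.002, 0.003]
A = [0.001, 0.002, 0.003]

runs = {}  # maps run index k -> (t, S, Z, R)
for k, params in enumerate(zip(P, d, B, G, A)):
    # t_eval reports the solution on the same grid the odeint version uses
    sol = solve_ivp(rhs, (t[0], t[-1]), y0, t_eval=t, args=params,
                    rtol=1e-8, atol=1e-8)
    S, Z, R = sol.y  # sol.y has shape (3, len(t))
    runs[k] = (sol.t, S, Z, R)
```

Each entry of `runs` can then be plotted or analysed after the loop, for example `runs[1][1]` is `S` for the second parameter set.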