numpy, ode, differential-equations

System of ODEs in matrix form


I am trying to figure out how to solve and plot the system dx/dt = Ax for a 2x2 matrix A. I don't really know how to do this. The code I currently have is as follows:

import numpy as np
import matplotlib.pyplot as plt
from scipy.integrate import odeint

def sys(x, t, A):
    x1, x2 = x    
    return [A @ x]

A = np.array([[1, 2], [3, 4]])


x0 = np.array([[1], [2]])

t = np.linspace(0, 100, 10000)

sol = odeint(sys, x0, t, A)

ax = plt.axes()
ax.plot(t, sol)
plt.show()

The error message is:

output = _odepack.odeint(func, y0, t, args, Dfun, col_deriv, ml, mu,
odepack.error: Extra arguments must be in a tuple.

Help on how to make this code work correctly would be very much appreciated. Please note, I am very new to coding, let alone coding differential equations. Thanks heaps.


Solution

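  • odeint takes the extra arguments for your function through its args parameter, and args has to be a tuple, even when there is only one extra argument. Your call puts the bare array A in that slot, which is what raises the error. Pass it like this:

    sol = odeint(sys, x0, t, args=(A,))

    See the documentation on scipy.integrate.odeint.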
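
    Beyond the args tuple, the code in the question has two other problems that odeint will reject: the initial condition has to be one-dimensional, and the function has to return a one-dimensional array rather than a list wrapping one. Below is a minimal sketch of a version that runs, not the only way to write it. The name rhs and the shorter time window are my own choices. I shortened the window because this A has a positive eigenvalue of roughly 5.37, so the solution grows like exp(5.37 t) and anything before the last instant would look flat on a plot over 0 to 100.

```python
import numpy as np
from scipy.integrate import odeint


def rhs(x, t, A):
    # Right-hand side of dx/dt = A x.
    # A @ x is already a 1-D array of length 2, which is what odeint expects.
    return A @ x


A = np.array([[1, 2], [3, 4]])

# odeint needs a 1-D initial condition, not a (2, 1) column vector.
x0 = np.array([1.0, 2.0])

# Assumption: a short window, chosen so the exponential growth stays readable.
t = np.linspace(0, 1, 200)

# Extra arguments for rhs go in a tuple; note the trailing comma.
sol = odeint(rhs, x0, t, args=(A,))

# sol has one row per time point and one column per component of x.
print(sol.shape)
```

    With that in place, the plotting lines from the question (ax.plot(t, sol)) draw x1 and x2 as two curves against t.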