Tags: python, scipy, model, curve-fitting

Integrating and fitting coupled ODEs for SIR modelling


In this case, there are three ODEs that describe an SIR model. The issue is that I want to find the beta and gamma values that best fit the data points given by the x_axis and y_axis values. My current approach is to integrate the ODEs with odeint from the scipy library and fit them with curve_fit from the same library. How would you calculate the values of beta and gamma that fit the data points?
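For reference, the three ODEs in question are the standard SIR system. Here S, I and R are fractions of the population, β is the infection rate and γ is the recovery rate (`beta` and `gamma` in the code):

```latex
\begin{aligned}
\frac{dS}{dt} &= -\beta S I \\
\frac{dI}{dt} &= \beta S I - \gamma I \\
\frac{dR}{dt} &= \gamma I
\end{aligned}
```

The three right-hand sides add up to zero, so S + I + R stays constant over time and equals 1 when the initial values are fractions of the population.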

P.S. The current error is: ValueError: operands could not be broadcast together with shapes (3,) (14,)


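import numpy as np
import scipy.integrate as spi
from scipy.optimize import curve_fit

# initial values (fractions of a population of 763)
S_I_R = (762/763, 1/763, 0)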

x_axis = [m for m in range(1,15)]
y_axis = [3,8,28,75,221,291,255,235,190,125,70,28,12,5]


# ODEs that describe the system
def equation(SIR_Values,t,beta,gamma):
    Array = np.zeros((3))
    SIR = SIR_Values
    Array[0] = -beta * SIR[0] * SIR[1]                  # dS/dt
    Array[1] = beta * SIR[0] * SIR[1] - gamma * SIR[1]  # dI/dt
    Array[2] = gamma * SIR[1]                           # dR/dt
    return Array


# Results = spi.odeint(equation,S_I_R,time)

# fitting the values (this call raises the ValueError quoted above)
beta_values,gamma_values = curve_fit(equation, x_axis,y_axis)
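The error comes from how `curve_fit` uses the model function. It calls `f(xdata, *params)` and subtracts `ydata` from the result, so `f` must return one value per data point. Because `p0` is not given, `curve_fit` inspects the signature of `equation`, treats `t`, `beta` and `gamma` as three parameters to fit, and passes `x_axis` in as `SIR_Values`. `equation` then returns its three derivatives, and an array of 3 values cannot be subtracted from 14 data points. The NumPy-only sketch below reproduces just that final subtraction. A zero array stands in for the output of `equation`, which is a simplification for illustration:

```python
import numpy as np

# Stand-in for what `equation` returns: three derivatives (dS, dI, dR).
model_output = np.zeros(3)

# The observed data: one value per day, 14 points.
y_axis = np.array([3, 8, 28, 75, 221, 291, 255, 235, 190, 125, 70, 28, 12, 5])

# curve_fit effectively computes model(xdata, *params) - ydata.
try:
    residuals = model_output - y_axis
    message = "no error"
except ValueError as err:
    message = str(err)

print(message)
```

The fix is to have the function passed to `curve_fit` integrate the ODEs over the time points and return a single compartment (here, the infected fraction), so its output has the same length as `y_axis`.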

Solution

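import numpy as np
import matplotlib.pyplot as plt
import scipy.integrate as spi
from scipy.optimize import curve_fit

# Starting values (fractions of a population of 763)
S0 = 762/763
I0 = 1/763
R0 = 0

# Day 0 is added with one case, so the data starts at the same
# point as the initial condition I0; counts are converted to fractions.
x_axis = np.array([m for m in range(0,15)])
y_axis = np.array([1,3,8,28,75,221,291,255,235,190,125,70,28,12,5])
y_axis = np.divide(y_axis,763)

def sir_model(y, x, beta, gamma):
    # y = (S, I, R); returns (dS/dt, dI/dt, dR/dt)
    dS = -beta * y[0] * y[1]
    dI = beta * y[0] * y[1] - gamma * y[1]
    dR = gamma * y[1]
    return dS, dI, dR

def fit_odeint(x, beta, gamma):
    # Integrate the system over the time points x and keep only the
    # infected column, so there is one output value per data point.
    return spi.odeint(sir_model, (S0, I0, R0), x, args=(beta, gamma))[:,1]

popt, pcov = curve_fit(fit_odeint, x_axis, y_axis)
beta, gamma = popt
print("beta =", beta, "gamma =", gamma)

fitted = fit_odeint(x_axis, *popt)
plt.plot(x_axis, y_axis, 'o', label = "infected per day")
plt.plot(x_axis, fitted, label = "fitted graph")
plt.xlabel("Time (in days)")
plt.ylabel("Fraction of infected")
plt.title("Fitted beta and gamma values")
plt.legend()
plt.show()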
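One way to sanity-check the `fit_odeint` approach is to fit it to synthetic data generated from known parameters and confirm that `curve_fit` recovers them. The sketch below assumes hypothetical true values `beta = 1.7` and `gamma = 0.45`, chosen only for illustration and not fitted from the data above. It also passes an explicit starting guess `p0`, because `curve_fit` otherwise starts every parameter at 1:

```python
import numpy as np
from scipy.integrate import odeint
from scipy.optimize import curve_fit

N = 763
S0, I0, R0 = (N - 1) / N, 1 / N, 0.0

def sir_model(y, t, beta, gamma):
    # y = (S, I, R) as fractions; returns (dS/dt, dI/dt, dR/dt)
    S, I, R = y
    dS = -beta * S * I
    dI = beta * S * I - gamma * I
    dR = gamma * I
    return dS, dI, dR

def fit_odeint(t, beta, gamma):
    # Infected fraction at each time point in t.
    return odeint(sir_model, (S0, I0, R0), t, args=(beta, gamma))[:, 1]

t = np.arange(15, dtype=float)

# Hypothetical "true" parameters, used only to generate noise-free test data.
true_beta, true_gamma = 1.7, 0.45
synthetic = fit_odeint(t, true_beta, true_gamma)

# Fit starting from a deliberately different guess.
popt, pcov = curve_fit(fit_odeint, t, synthetic, p0=(1.0, 0.5))
beta_hat, gamma_hat = popt
print("beta =", beta_hat, "gamma =", gamma_hat)
```

With noise-free data the fitted values should come out very close to the ones used to generate it. If they do not, the starting guess or the model function is the first thing to check before fitting real data.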