Search code examples
pythonarraysnumpyodemodel-fitting

odr function for fitting ODEs: formatting the np array


I found this code (Curve fitting to coupled ODEs) and would like to apply it to my example. However, I have a system of three ODEs instead of one like in the example. I think I have formatted the NumPy array incorrectly at some point, because I receive the error message:
' raise OdrError("fcn does not output %s-shaped array" % y_s)

OdrError: fcn does not output [273]-shaped array'

import numpy as np
from scipy.integrate import odeint
import matplotlib.pyplot as plt
from scipy.odr import Model, Data, ODR

X0=np.array([1.88217580e+01,  6.39178479e+00,  2.17062151e+00,  7.37133388e-01,
        2.50327190e-01,  8.50099826e-02,  2.88690076e-02,  9.80376865e-03,
        3.32931467e-03,  1.13061872e-03,  3.83953088e-04,  1.30385145e-04,
        4.42782276e-05,  1.50356782e-05,  5.10477998e-06,  1.73382003e-06,
        5.89295908e-07,  2.00567131e-07,  6.90030549e-08,  2.57878131e-08,
        9.67160308e-09,  1.97119162e-09, -1.69012581e-09, -1.71545398e-09,
       -3.71472345e-10,  2.30848671e-10,  7.16248000e-11, -1.55728077e-10,
       -1.42665015e-10,  6.05398499e-12,  1.85058432e-10,  3.41362199e-10,
        3.88187908e-10,  3.51166400e-10,  2.65669133e-10,  1.60540390e-10,
        6.41768726e-11,  1.08933660e-11,  4.57027140e-12,  1.10891069e-11,
        2.56927243e-11,  2.78420609e-11,  1.18751279e-11, -3.62281809e-12,
       -8.46085327e-12, -1.39074100e-12,  4.91246015e-12,  7.30936087e-12,
        1.32752431e-12, -3.13381780e-12, -8.11612234e-12, -1.00280844e-11,
       -5.53945281e-12, -1.01747651e-13,  2.04901659e-12,  1.13451885e-12,
       -9.04619801e-13, -1.84246991e-12, -9.54025800e-13,  2.78887763e-13,
        1.14911302e-12,  6.48264819e-13, -3.46980469e-14, -6.75972274e-13,
       -6.37761429e-13, -3.82725056e-13,  1.36698883e-13,  6.92178214e-13,
        6.53658276e-13,  5.49382735e-13,  1.80455683e-13, -2.33472724e-13,
       -3.00186786e-13, -3.11933704e-13, -1.37129091e-13,  1.13089169e-13,
        1.79605942e-13,  3.01644773e-13,  2.48001098e-13,  8.43594772e-14,
       -2.49816805e-14,  2.30314000e-14,  1.47890730e-14,  3.08335407e-14,
        8.00840951e-14,  1.20170440e-13,  1.20712574e-13,  1.31143818e-13,
        1.30834746e-13,  1.00040346e-13,  4.87840502e-14])
Xp=np.array([0.        , 3.74954652, 5.54244266, 6.0479111 , 6.11417888,
       6.03392821, 5.90652069, 5.76563354, 5.62263923, 5.48133732,
       5.34295714, 5.20785681, 5.07610002, 4.94765199, 4.8224459 ,
       4.70040546, 4.58145249, 4.46550957, 4.3525007 , 4.24235175,
       4.13499035, 4.03034588, 3.9283496 , 3.82893457, 3.73203547,
       3.6375886 , 3.5455319 , 3.45580487, 3.36834858, 3.28310554,
       3.20001976, 3.11903664, 3.04010298, 2.96316691, 2.88817786,
       2.81508655, 2.74384496, 2.67440629, 2.60672491, 2.54075635,
       2.47645726, 2.41378539, 2.35269956, 2.29315964, 2.23512649,
       2.17856199, 2.12342898, 2.06969122, 2.01731341, 1.96626112,
       1.91650082, 1.8679998 , 1.82072621, 1.77464897, 1.72973781,
       1.68596322, 1.64329643, 1.60170942, 1.56117485, 1.5216661 ,
       1.48315719, 1.44562283, 1.40903836, 1.37337973, 1.33862352,
       1.30474688, 1.27172757, 1.23954388, 1.20817466, 1.1775993 ,
       1.14779772, 1.11875033, 1.09043804, 1.06284225, 1.03594484,
       1.00972811, 0.98417486, 0.95926828, 0.93499202, 0.91133011,
       0.88826703, 0.8657876 , 0.84387706, 0.82252101, 0.80170542,
       0.78141661, 0.76164125, 0.74236635, 0.72357923, 0.70526757,
       0.68741932])
Xc=np.array([0.        , 1.80507171, 0.66166362, 0.25919594, 0.12154879,
       0.07395521, 0.05696567, 0.05039004, 0.04737134, 0.04558043,
       0.04422586, 0.04303835, 0.04192599, 0.04085709, 0.03982044,
       0.0388118 , 0.03782928, 0.03687183, 0.03593868, 0.03502916,
       0.03414267, 0.03327862, 0.03243643, 0.03161556, 0.03081546,
       0.03003561, 0.0292755 , 0.02853462, 0.0278125 , 0.02710864,
       0.0264226 , 0.02575392, 0.02510217, 0.02446691, 0.02384772,
       0.02324421, 0.02265596, 0.0220826 , 0.02152376, 0.02097906,
       0.02044814, 0.01993066, 0.01942627, 0.01893465, 0.01845547,
       0.01798842, 0.01753318, 0.01708946, 0.01665698, 0.01623545,
       0.01582458, 0.0154241 , 0.01503376, 0.0146533 , 0.01428247,
       0.01392102, 0.01356872, 0.01322533, 0.01289064, 0.01256442,
       0.01224645, 0.01193653, 0.01163445, 0.01134001, 0.01105303,
       0.01077331, 0.01050067, 0.01023493, 0.00997591, 0.00972345,
       0.00947738, 0.00923754, 0.00900376, 0.0087759 , 0.00855381,
       0.00833734, 0.00812634, 0.00792069, 0.00772024, 0.00752486,
       0.00733443, 0.00714882, 0.0069679 , 0.00679157, 0.00661969,
       0.00645217, 0.00628888, 0.00612973, 0.0059746 , 0.0058234 ,
       0.00567603])

# Stack the three measured series as rows, in the order X0, Xc, Xp,
# giving a (3, N) array.
P=np.array([X0,Xc,Xp])    
tmax, Nt = 90, int(90)
# Times at which the solution is to be computed.
# Nt+1 equally spaced points from 0 to tmax inclusive.
t = np.linspace(0, tmax, Nt+1)


def coupledODE(beta, x):
    """Integrate the three coupled ODEs and return the solution flattened to 1-D.

    beta -- the four parameters, unpacked as (ka, Cltp, Clpt, Cl)
    x    -- time points handed to odeint

    The initial state is taken from the first entries of the module-level
    arrays X0, Xc and Xp.
    """
    ka,Cltp,Clpt,Cl = beta


    # Three coupled ODEs
    def conc (y, t) : 
        # The names below are locals of conc; they hold the derivatives and
        # shadow the module-level data arrays of the same name.
        X0=-ka*y[0]
        Xc=ka*y[0]+Cltp*y[2]-Clpt*y[1]- Cl*y[1]
        Xp=Clpt*y[1]-Cltp*y[2]
        f= np.array([X0,Xc,Xp])
        return f


    # Initial conditions for y[0], y[1] and y[2]
    # (here X0, Xc, Xp are the module-level measurement arrays)
    y0 = np.array([X0[0],Xc[0], Xp[0]])



    # Solve the equation
    # odeint returns one row per time point and one column per state,
    # i.e. an array of shape (len(x), 3).
    y = odeint(conc, y0, x)

    #return y[:,1]
    # in case you are only fitting to experimental findings of ODE #1
    print('y1',y.shape)
    # NOTE(review): ravel() yields 3*len(x) values, so the output can never
    # have the same length as x; this looks like the source of the
    # "fcn does not output [273]-shaped array" error -- confirm.
    y=y.ravel()
    print('y2',y.shape)
    return y

    # in case you have experimental findings of all three ODEs

#data = Data(t, P)
# with P being experimental findings of ODE #1

# NOTE(review): np.repeat(t, 3) produces t0,t0,t0,t1,t1,t1,... (3*N values),
# whereas P.ravel() on the (3,N) array lists the whole X0 series, then Xc,
# then Xp -- the two orderings do not correspond to each other.
data = Data(np.repeat(t, 3), P.ravel())
# with P being a (3,N) array of experimental findings of all ODEs

model = Model(coupledODE)
#guess = [0.1,0.1,0.1]
# Starting values for the four parameters (ka, Cltp, Clpt, Cl).
guess = [0.15,0.15,0.15,0.15]

odr = ODR(data, model, guess)
# fit_type 2 selects ordinary least squares instead of full ODR.
odr.set_job(2)
out = odr.run()
# Fitted parameters and their standard deviations.
print(out.beta)
print(out.sd_beta)

f = plt.figure()
p = f.add_subplot(111)
# Only the first measured series (X0) is plotted here.
p.plot(t, P[0], 'ro')
# NOTE(review): coupledODE returns a flattened array of 3*len(t) values,
# which presumably does not match len(t) for plotting -- confirm.
p.plot(t, coupledODE(out.beta, t))
plt.show()

Any help with this issue would be highly appreciated


Solution

  • The documentation of ODR is not very forthcoming in the non-scalar case, but in the end it is quite logical.

    The purpose of the ODR sub-package is to fit a function y=f(b,x) to a given data set of argument-value pairs (x[k], y[k]), k in range(n). As part of this the function needs to be repeatedly evaluated on the full input data set for different parameters b, thus f has to primarily support a "vectorized" evaluation.

    The model function f thus takes an input x of length n for a scalar domain sampled at n points, or an array of shape (m,n) for input vectors with m components. Similarly, the output y has to be an array of length n for scalar output, or of shape (q,n) for vector output.

    While the n has to be the same in all instances, m and q are independent. In the present case, m=1 and q=3, that is, the input is scalar.

    The script thus has to be corrected by removing all flattening operations. Some transpositions are needed because the result of odeint has shape (n,3), while ODR expects (3,n); the (n,3) layout is also the format the plot command requires to plot multiple lines in one go.

    def coupledODE(beta, x):
        """Solve the three coupled ODEs for the parameters beta at the times x.

        beta -- the four parameters, in the order (ka, Cltp, Clpt, Cl)
        x    -- time points at which odeint evaluates the solution

        The initial state comes from the first entries of the module-level
        arrays X0, Xc and Xp. The result has shape (3, len(x)): one row per
        state variable.
        """
        ka, Cltp, Clpt, Cl = beta

        def rhs(state, _time):
            """Right-hand side of the system; the time argument is unused."""
            q0, q1, q2 = state
            d0 = -ka*q0
            d1 = ka*q0 + Cltp*q2 - Clpt*q1 - Cl*q1
            d2 = Clpt*q1 - Cltp*q2
            return np.array([d0, d1, d2])

        # Start from the first measured value of each series.
        start = np.array([X0[0], Xc[0], Xp[0]])

        # odeint yields (N, q); ODR wants (q, N), hence the transpose.
        return odeint(rhs, start, x).T
    
    # Scalar input t (length N) paired with the vector-valued observations P.
    data = Data(t, P)
    # with P being a (3,N) array of experimental findings of all ODEs
    
    model = Model(coupledODE)
    # Starting values for the four parameters (ka, Cltp, Clpt, Cl).
    guess = [0.15,0.15,0.15,0.15]
    
    odr = ODR(data, model, guess)
    # fit_type 2 selects ordinary least squares instead of full ODR.
    odr.set_job(2)
    out = odr.run()
    # Fitted parameters and their standard deviations.
    print(out.beta)
    print(out.sd_beta)
    
    f = plt.figure()
    p = f.add_subplot(111)
    # Evaluate the fitted model on a finer grid (300 points) for a smooth curve.
    t_plot = np.linspace(t[0],t[-1],300)
    # Transpose back to (n, 3) so that plot draws one line per column.
    F = coupledODE(out.beta, t_plot).T
    print(P.shape,F.shape)
    # Measurements as small red dots, one series per column of P.T.
    p.plot(t, P.T, 'ro', ms=2)
    p.plot(t_plot, F)
    plt.show()
    

    resulting in a parameter vector

    [1.08       0.04       1.74       3.13000001]
    

    and a picture with a visually good fit of the plots

    plot of the fit