Tags: python, scipy, differential-equations

Fit a differential equation using scipy, getting "object too deep for desired array"


I'm trying to fit a curve to a differential equation. For simplicity, I'm just using the logistic equation here. I wrote the code below, but running it produces the error shown after it. I'm not sure what I'm doing wrong.

import numpy as np
import pandas as pd
import scipy.optimize as optim
from scipy.integrate import odeint

df_yeast = pd.DataFrame({'cd': [9.6, 18.3, 29., 47.2, 71.1, 119.1, 174.6, 257.3, 350.7, 441., 513.3, 559.7, 594.8, 629.4, 640.8, 651.1, 655.9, 659.6], 'td': np.arange(18)})

N0 = 1
parsic = [5, 2]

def logistic_de(t, N, r, K):
    return r*N*(1 - N/K)

def logistic_solution(t, r, K):
    return odeint(logistic_de, N0, t, (r, K))

params, _ = optim.curve_fit(logistic_solution, df_yeast['td'], df_yeast['cd'], p0=parsic);
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
ValueError: object too deep for desired array

---------------------------------------------------------------------------
error                                     Traceback (most recent call last)
<ipython-input-94-2a5a467cfa43> in <module>
----> 1 params, _ = optim.curve_fit(logistic_solution, df_yeast['td'], df_yeast['cd'], p0=parsic);

~/SageMath/local/lib/python3.9/site-packages/scipy/optimize/minpack.py in curve_fit(f, xdata, ydata, p0, sigma, absolute_sigma, check_finite, bounds, method, jac, **kwargs)
    782         # Remove full_output from kwargs, otherwise we're passing it in twice.
    783         return_full = kwargs.pop('full_output', False)
--> 784         res = leastsq(func, p0, Dfun=jac, full_output=1, **kwargs)
    785         popt, pcov, infodict, errmsg, ier = res
    786         ysize = len(infodict['fvec'])

~/SageMath/local/lib/python3.9/site-packages/scipy/optimize/minpack.py in leastsq(func, x0, args, Dfun, full_output, col_deriv, ftol, xtol, gtol, maxfev, epsfcn, factor, diag)
    420         if maxfev == 0:
    421             maxfev = 200*(n + 1)
--> 422         retval = _minpack._lmdif(func, x0, args, full_output, ftol, xtol,
    423                                  gtol, maxfev, epsfcn, factor, diag)
    424     else:

error: Result from function call is not a proper array of floats.
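The first message, "object too deep for desired array", is a complaint about array shape. One way to see where it comes from is to check what logistic_solution actually returns. The sketch below is illustrative and not from the original post. It reuses the question's logistic_de and starting values, passes tfirst=True so that odeint calls logistic_de with t first (this does not affect the shape), and uses an array of zeros as a hypothetical stand-in for the 18 observations.

```python
import numpy as np
from scipy.integrate import odeint

N0 = 1

def logistic_de(t, N, r, K):
    # Logistic right-hand side: dN/dt = r*N*(1 - N/K)
    return r * N * (1 - N / K)

t = np.arange(18, dtype=float)
sol = odeint(logistic_de, N0, t, args=(5, 2), tfirst=True)

# odeint returns one row per time point and one column per state variable,
# so even a scalar initial condition gives a 2-D result.
print(sol.shape)  # (18, 1)

# curve_fit builds residuals as model output minus data. Broadcasting an
# (18, 1) array against an (18,) array yields (18, 18), not 18 residuals.
ydata = np.zeros(18)  # hypothetical stand-in for df_yeast['cd']
print((sol - ydata).shape)  # (18, 18)

# Flattening restores one model value per data point.
print(sol.ravel().shape)  # (18,)
```

A 2-D residual array is what MINPACK's leastsq rejects as "not a proper array of floats".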

Solution

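  • @hpaulj has pointed out the problem with the shape of the return value from logistic_solution: odeint returns a 2-D array of shape (len(t), 1), while curve_fit expects the model function to return a 1-D array with one value per data point. He also showed that fixing this eliminates the error that you reported; flattening the result, for example with .ravel() or [:, 0], is enough.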

    There is, however, another problem in the code. It does not raise an error, but it does produce an incorrect solution to your test problem (the logistic differential equation). By default, odeint calls the function that computes the derivatives as func(y, t, ...), so it expects t to be the second argument, whereas your logistic_de takes t first. Either swap the first two arguments of logistic_de, or add the argument tfirst=True to your call of odeint. The second option is a bit nicer, because it lets you use logistic_de unchanged with scipy.integrate.solve_ivp (which always passes t first) if you decide to try that function instead of odeint.
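Putting both fixes together, a minimal sketch of the corrected fit might look like the following. It keeps the question's data, N0 and logistic_de, flattens the odeint output with .ravel() and passes tfirst=True. Two things are my own assumptions rather than the question's: the starting guess p0 = [1, 650] (a growth rate of order 1 per time unit and the plateau read off the data, instead of [5, 2]), and the closed-form check N(t) = K / (1 + (K/N0 - 1) e^(-rt)), which is the standard solution of the logistic equation and serves only to confirm that the argument order is now right.

```python
import numpy as np
import pandas as pd
import scipy.optimize as optim
from scipy.integrate import odeint, solve_ivp

df_yeast = pd.DataFrame({'cd': [9.6, 18.3, 29., 47.2, 71.1, 119.1, 174.6, 257.3, 350.7, 441.,
                                513.3, 559.7, 594.8, 629.4, 640.8, 651.1, 655.9, 659.6],
                         'td': np.arange(18)})

N0 = 1

def logistic_de(t, N, r, K):
    # dN/dt = r*N*(1 - N/K), with the time argument first (the solve_ivp convention).
    return r * N * (1 - N / K)

def logistic_solution(t, r, K):
    # tfirst=True makes odeint call logistic_de(t, N, r, K) rather than (N, t, r, K).
    # .ravel() flattens odeint's (len(t), 1) output into the 1-D array curve_fit expects.
    return odeint(logistic_de, N0, t, args=(r, K), tfirst=True).ravel()

# Assumed starting guess read off the data (not the question's [5, 2]).
p0 = [1, 650]
t_data = df_yeast['td'].to_numpy(dtype=float)
params, _ = optim.curve_fit(logistic_solution, t_data, df_yeast['cd'], p0=p0)
r_fit, K_fit = params
print(r_fit, K_fit)

# Compare the numerical solution with the closed-form logistic solution.
exact = K_fit / (1 + (K_fit / N0 - 1) * np.exp(-r_fit * t_data))
print(np.allclose(logistic_solution(t_data, r_fit, K_fit), exact, rtol=1e-4))

# The same right-hand side works unchanged with solve_ivp, which always passes t first.
ivp = solve_ivp(logistic_de, (t_data[0], t_data[-1]), [N0], t_eval=t_data,
                args=(r_fit, K_fit), rtol=1e-8, atol=1e-8)
print(np.allclose(ivp.y[0], exact, rtol=1e-4))
```

With the original argument order, odeint would instead evaluate r*t*(1 - t/K), which ignores N entirely; that is why the mistake yields a wrong curve rather than an exception.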