How can I solve this `curve_fit` error? It raises `ValueError: object too deep for desired array`, followed by `error: Result from function call is not a proper array of floats.`
### Import Libraries
import numpy as np
from scipy.optimize import curve_fit
### Define Function
def Func(vars, C1, C2):
    """Evaluate the two-term surface model Z(X, Y) for coefficients C1 and C2.

    Each term has the form ``C * r**2 / (1 + sqrt(1 - (C * r)**2))``; the
    first term uses ``Y`` with ``C1``, the second uses ``X`` with ``C2``.

    Parameters
    ----------
    vars : pair ``(X, Y)``
        Coordinates at which the model is evaluated (scalars or arrays that
        broadcast against each other).
    C1, C2 : float
        Coefficients of the ``Y`` term and of the ``X`` term respectively.

    Returns
    -------
    ``Z1 + Z2`` with the broadcast shape of ``X`` and ``Y``. The result is
    NOT flattened, so 2-D inputs give a 2-D output.
    """
    (X, Y) = vars
    # The square roots are real only while |C1*Y| <= 1 and |C2*X| <= 1.
    Z1 = (C1*Y**2) / (1+(1-(C1*Y)**2)**0.5)
    Z2 = (C2*X**2) / (1+(1-(C2*X)**2)**0.5)
    return Z1 + Z2
### Sample axes: xL has 11 points on [0, 10], yL has 101 points on [0, 100]
xL = np.linspace(0.0, 10, 11).flatten() ## Sub
yL = np.linspace(0.0, 100, 101).flatten() ## Main
### 2-D coordinate grids built from the two axes
X, Y = np.meshgrid(abs(xL), abs(yL))
### Coefficients used to generate the reference surface
C1 = 0.002
C2 = 0.005
### Calculate : Original and Noise Data
Z_original = Func((X, Y), C1, C2)
### Gaussian noise (scale 0.5) is drawn as a flat vector, one sample per grid point ...
Z_noise = np.random.normal(size=(len(xL)*len(yL)), scale=0.5)
### ... then resized in place to (len(yL), len(xL)) before being added to the surface
Z_noise.resize(len(yL), len(xL))
Z_noise = Z_original + Z_noise
### Curve_Fit: this is the step the question is about. The curve_fit call below
### is the one reported to fail with "object too deep for desired array";
### the value returned by Func for the grids (X, Y) is not flattened to 1-D here.
p0 = (0.002, 0.005)
popt, pcov = curve_fit(Func, (X,Y), Z_noise, p0)
Z_curvefit = Func((X,Y), *popt)
`scipy.optimize.curve_fit` can be used to fit 2-D data, but the dependent data (the output of the model function, and the `ydata` it is compared against) must still be 1-D, as stated in the SciPy documentation.
A solution is to use `np.ravel()` to flatten the return value of `func`:
import numpy as np
from scipy.optimize import curve_fit
import matplotlib.pyplot as plt
def func(data, c1, c2):
    """Evaluate the two-term surface model and return it as a 1-D array.

    Each term has the form ``c * r**2 / (1 + sqrt(1 - (c * r)**2))``; the
    first term uses ``Y`` with ``c1``, the second uses ``X`` with ``c2``.

    Parameters
    ----------
    data : pair ``(X, Y)``
        Coordinate arrays (e.g. from ``np.meshgrid``) that broadcast against
        each other.
    c1, c2 : float
        Coefficients of the ``Y`` term and of the ``X`` term respectively.

    Returns
    -------
    numpy.ndarray
        ``Z1 + Z2`` flattened to 1-D, which is the form ``curve_fit``
        requires for the model output.
    """
    (X, Y) = data
    # The square roots are real only while |c1*Y| <= 1 and |c2*X| <= 1.
    Z1 = (c1*Y**2) / (1+(1-(c1*Y)**2)**0.5)
    Z2 = (c2*X**2) / (1+(1-(c2*X)**2)**0.5)
    return (Z1 + Z2).ravel() # <-- add ravel() here
# Sample axes; linspace already returns non-negative 1-D arrays, so no
# flatten()/abs() is needed before building the grid.
xL = np.linspace(0.0, 10, 11)  ## Sub
yL = np.linspace(0.0, 100, 101)  ## Main
X, Y = np.meshgrid(xL, yL)

# Coefficients used to generate the reference surface
c1 = 0.002
c2 = 0.005

# Original and noisy Data (both 1-D, because func() ravels its result)
Z_original = func((X, Y), c1, c2)
Z_noise = np.random.normal(size=X.size, scale=0.5) # <-- no resizing necessary now
Z = Z_original + Z_noise

# Curve fitting
p0 = (0.002, 0.005)
popt, pcov = curve_fit(func, (X, Y), Z, p0)
Z_curvefit = func((X, Y), *popt)

# plot (reshape Z and Z_curvefit first); the grid shape is taken from X so
# it stays correct if the number of samples in xL or yL changes
fig, ax = plt.subplots(1, 1)
ax.imshow(Z.reshape(X.shape),
          cmap=plt.cm.jet, origin='lower',
          extent=(X.min(), X.max(), Y.min(), Y.max()))
ax.contour(X, Y, Z_curvefit.reshape(X.shape), 8, colors='w')
plt.show()