The code below generates a plot and a 4PL curve fit, but the fit is poor at lower values. This kind of error can usually be addressed by adding a 1/y^2 weighting, but I don't know how to do it in this instance. Adding `sigma=1/Y_data**2`
to the fit just makes it worse.
import numpy as np
from scipy.optimize import curve_fit
import matplotlib.pyplot as plt
def fourPL(x, A, B, C, D):
    """Evaluate the four-parameter logistic (4PL) model at ``x``.

    For ``B > 0`` the response equals ``A`` at ``x = 0`` and approaches ``D``
    as ``x`` grows. ``C`` is the midpoint, where the response is
    ``(A + D) / 2``, and ``B`` controls how steep the transition is.
    """
    return ((A - D) / (1.0 + np.power(x / C, B))) + D
X_data = np.array([700, 200, 44, 11, 3, 0.7, 0.2, 0])
Y_data = np.array([600000, 140000, 30000, 8000, 2100, 800, 500, 60])

# Unweighted fit in the linear domain, with curve_fit's default initial
# guess (all parameters 1) and no bounds. Residuals here are absolute, so
# the large-y points dominate the objective and the low end fits poorly.
popt, pcov = curve_fit(fourPL, X_data, Y_data)

fig, ax = plt.subplots()
ax.scatter(X_data, Y_data, label='Data')

# Sample the curve at log-spaced points over the positive data range:
# x = 0 cannot be drawn on a log axis, and linearly spaced samples would
# resolve the low decades of the log axis only sparsely.
positive_x = X_data[X_data > 0]
X_curve = np.geomspace(positive_x.min(), X_data.max(), 5000)
Y_curve = fourPL(X_curve, *popt)
ax.plot(X_curve, Y_curve, label='curve_fit')

ax.set_xscale('log')
ax.set_yscale('log')
ax.legend()
plt.show()
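Don't add inverse-square weights; fit in the log domain instead. There, each residual measures a relative rather than an absolute error, so the small values count as much as the large ones. (Note that `sigma` in `curve_fit` is a standard deviation, not a weight: residuals are divided by `sigma` before squaring. `sigma=1/Y_data**2` therefore weights each squared residual by y^4 and favours the large values even more; a 1/y^2 weighting would be `sigma=Y_data`.) Always add bounds. And in this case, `curve_fit`
doesn't do a very good job; consider `minimize` instead.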
import numpy as np
from scipy.optimize import curve_fit, minimize
import matplotlib.pyplot as plt
def fourPL(x: np.ndarray, a: float, b: float, c: float, d: float) -> np.ndarray:
    """Evaluate the four-parameter logistic (4PL) model at ``x``.

    For ``b > 0`` the response equals ``a`` at ``x = 0`` and approaches ``d``
    as ``x`` grows. ``c`` is the midpoint, where the response is
    ``(a + d) / 2``, and ``b`` controls how steep the transition is.
    """
    return (a - d) / (1 + (x / c)**b) + d
def estimated(x: np.ndarray, a: float, b: float, c: float, d: float) -> np.ndarray:
    """Return the natural log of the 4PL model evaluated at ``x``.

    Fitting this against ``np.log(y)`` makes each residual
    ``log(model) - log(y) = log(model / y)``, i.e. a relative rather than an
    absolute error, so small and large responses carry comparable weight.
    The model must be positive at every ``x`` for the result to be finite.
    """
    return np.log(fourPL(x, a, b, c, d))
def sqerror(abcd: np.ndarray) -> float:
    """Return the sum of squared log-domain residuals for parameters ``abcd``.

    Objective function for ``scipy.optimize.minimize``; ``abcd`` is the
    parameter vector ``(a, b, c, d)`` of the 4PL model. Reads the
    module-level ``x_data`` and ``y_data`` arrays, which must be defined
    before this is called.
    """
    # Same residual that curve_fit sees when fitting `estimated` to log(y).
    residuals = estimated(x_data, *abcd) - np.log(y_data)
    return residuals.dot(residuals)
x_data = np.array([700, 200, 44, 11, 3, 0.7, 0.2, 0])
y_data = np.array([600000, 140000, 30000, 8000, 2100, 800, 500, 60])

# Initial guess for (a, b, c, d), shared by both optimizers below.
guess = (500, 1.05, 1e6, 1e9)

# Row 0 holds the lower bounds and row 1 the upper bounds for (a, b, c, d),
# i.e. the (lower, upper) layout curve_fit expects. minimize instead wants
# one (min, max) pair per parameter, hence `bounds.T` further down.
bounds = np.array((
    (1, 0.1, 1, 0),
    (np.inf, 10, np.inf, np.inf),
))

# Fit log(model) to log(data) so each point contributes a relative error.
popt, _ = curve_fit(
    f=estimated, xdata=x_data, ydata=np.log(y_data), p0=guess,
    bounds=bounds,
)
print('popt:', popt)

# Minimise the same log-domain squared error directly.
result = minimize(
    fun=sqerror, x0=guess, bounds=bounds.T, tol=1e-9,
)
# Explicit check rather than `assert`, which is stripped under `python -O`.
if not result.success:
    raise RuntimeError(f'minimize did not converge: {result.message}')
print('minimize x:', result.x)

# Log-spaced sample points (0.1 to 1000) so the curves are evenly resolved
# on the log x-axis.
x_curve = np.logspace(-1, 3, 1000)

fig, ax = plt.subplots()
ax.scatter(x_data, y_data, label='Data')
ax.plot(x_curve, fourPL(x_curve, *popt), label='curve_fit')
ax.plot(x_curve, fourPL(x_curve, *result.x), label='minimize')
ax.plot(x_curve, fourPL(x_curve, *guess), label='guess')
ax.set_xscale('log')
ax.set_yscale('log')
ax.legend()
plt.show()