Search code examples
Tags: python, scipy, scipy-optimize

Apply a weighting to a 4-parameter logistic (4PL) regression curve fit


The code below generates a plot and a 4PL curve fit, but the fit is poor at lower values. This error can usually be addressed by adding a `1/y^2` weighting, but I don't know how to do it in this instance. Adding `sigma=1/Y_data**2` to the fit just makes it worse.

import matplotlib.pyplot as plt
import numpy as np
from scipy.optimize import curve_fit

def fourPL(x, A, B, C, D):
    """Evaluate the four-parameter logistic (4PL) curve at ``x``.

    The result is ``A`` at x == 0 (for B > 0). It tends to ``D`` as
    (x / C)**B grows, and equals the midpoint (A + D) / 2 at x == C.
    ``B`` controls how steep the transition is.
    """
    # Express the curve as D plus a fraction of the span (A - D).
    span = A - D
    return D + span / (1.0 + np.power(x / C, B))

# Sample data: x values (including x = 0) and responses spanning ~4 decades.
X_data = np.array([700,200,44,11,3,0.7,0.2,0])
Y_data = np.array([600000,140000,30000,8000,2100,800,500,60])

# Unweighted least-squares 4PL fit: residuals are absolute, so the largest
# responses dominate the cost function.
popt, pcov = curve_fit(fourPL, X_data, Y_data)

fig, ax = plt.subplots()
ax.scatter(X_data, Y_data, label='Data')
# x = 0 cannot be drawn on a log axis, so start the curve at the smallest
# positive x; geomspace spaces the samples evenly on that log axis, so the
# low end of the curve is as finely resolved as the high end.
X_curve = np.geomspace(X_data[X_data > 0].min(), X_data.max(), 5000)
Y_curve = fourPL(X_curve, *popt)
ax.plot(X_curve, Y_curve, label='4PL fit')

ax.set_xscale('log')
ax.set_yscale('log')
# Without a legend the labels above are never shown.
ax.legend()

plt.show()

Plot


Solution

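  • Don't add inverse-square weights; fit in the log domain instead. Always add bounds. In this case `curve_fit` doesn't do a very good job, so consider `minimize` instead. (If you do want `1/y^2` weighting with `curve_fit`, note the units of `sigma`. It is a per-point standard deviation that each residual is divided by, so `1/y^2` weighting of the squared residuals corresponds to `sigma=y_data`. Passing `sigma=1/y_data**2` instead weights the large values even more heavily, which is why it made the fit worse.)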

    import numpy as np
    from scipy.optimize import curve_fit, minimize
    import matplotlib.pyplot as plt
    
    
    def fourPL(x: np.ndarray, a: float, b: float, c: float, d: float) -> np.ndarray:
        """Evaluate the four-parameter logistic (4PL) curve at ``x``.

        The value is ``a`` at x == 0 (for b > 0). It tends to ``d`` as
        (x / c)**b grows, and is halfway between them, (a + d) / 2, at
        x == c. ``b`` sets the steepness of the transition.
        """
        # Write the curve as d plus a shrinking fraction of the span (a - d).
        span = a - d
        return d + span / (1 + (x / c) ** b)
    
    
    def estimated(x: np.ndarray, a: float, b: float, c: float, d: float) -> np.ndarray:
        """Return the natural log of the 4PL model evaluated at ``x``.

        This is the model handed to ``curve_fit`` together with
        ``np.log(y_data)``, so the fit's residuals are measured in log space.
        """
        model = fourPL(x, a, b, c, d)
        return np.log(model)
    
    
    def sqerror(abcd: np.ndarray) -> float:
        """Return the sum of squared log-residuals of the 4PL model.

        ``abcd`` packs the parameters (a, b, c, d). The data come from the
        script-level ``x_data`` and ``y_data``. This is the objective
        passed to ``minimize``.
        """
        predicted = np.log(fourPL(x_data, *abcd))
        observed = np.log(y_data)
        residual = predicted - observed
        return np.dot(residual, residual)
    
    
    # Data from the question; y spans ~4 decades, which is why the fit is
    # done on log(y) rather than on y directly.
    x_data = np.array([700, 200, 44, 11, 3, 0.7, 0.2, 0])
    y_data = np.array([600000, 140000, 30000, 8000, 2100, 800, 500, 60])
    # Initial (a, b, c, d), shared by both optimizers.
    guess = (500, 1.05, 1e6, 1e9)
    # Row 0 = lower bounds, row 1 = upper bounds, in (a, b, c, d) order:
    # the layout curve_fit expects. minimize wants one (min, max) pair per
    # parameter instead, hence bounds.T below. With a >= 1 and d >= 0 the
    # model stays positive, so np.log in the objective is always defined.
    bounds = np.array((
        (1, 0.1, 1, 0),
        (np.inf, 10, np.inf, np.inf),
    ))

    # Least squares on log(y): relative errors count equally at every scale.
    popt, _ = curve_fit(
        f=estimated, xdata=x_data, ydata=np.log(y_data), p0=guess,
        bounds=bounds,
    )
    print('popt:', popt)
    # Same log-domain objective, minimized directly under the same bounds.
    result = minimize(
        fun=sqerror, x0=guess, bounds=bounds.T, tol=1e-9,
    )
    # Raise instead of assert: assert statements are stripped under `python -O`.
    if not result.success:
        raise RuntimeError(f'minimize did not converge: {result.message}')
    print('minimize x:', result.x)

    # 1000 points evenly spaced on the log axis, from 0.1 to 1000.
    x_curve = np.logspace(-1, 3, 1000)

    fig, ax = plt.subplots()
    ax.scatter(x_data, y_data, label='Data')
    ax.plot(x_curve, fourPL(x_curve, *popt), label='curve_fit')
    ax.plot(x_curve, fourPL(x_curve, *result.x), label='minimize')
    ax.plot(x_curve, fourPL(x_curve, *guess), label='guess')
    ax.set_xscale('log')
    ax.set_yscale('log')
    ax.legend()
    plt.show()
    

    fits