Search code examples
python, scikit-learn, logistic-regression

logistic like curve fitting using machine learning


I asked this question in data science threads, but didn't get an answer. Hence posting here.

I have a set of points of a function k(x). I am trying to do some curve fitting to find the exact k(x) function. It seems that the data points fit a logistic-like curve, only slightly shifted and stretched.

So far I have tried polynomial regression, but I don't feel the fitting is correct. I have attached a snap of the fitted curve here.

So my question is, is logistic regression only used in classification tasks? Or can it be used for curve fitting?

If not, what other techniques, besides polynomial regression, are available to fit a logistic-like curve to a set of data points?

EDIT

Following is the code. (x,y) are the datapoints.

import matplotlib.pyplot as plt 
import numpy as np
from sklearn.linear_model import Ridge
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import PolynomialFeatures
from sklearn.metrics import mean_squared_error
from sklearn.linear_model import LogisticRegression

# Measured data points (x, y) of the unknown function k(x).
x = np.array([0.3, 0.4, 0.5, 0.6, 0.65, 0.67, 0.8])
y = np.array([-936, -892, -178.33, -50.7, -65.7, -70.44, -9])

degree = 5

# Polynomial regression: expand x into [1, x, x^2, ..., x^degree] and fit
# with a (practically unregularised) ridge. The intercept is carried by the
# constant feature column, hence fit_intercept=False.
# NOTE: LogisticRegression is a classifier and cannot be fitted to
# continuous targets like these, so it is not a drop-in replacement here.
model = make_pipeline(PolynomialFeatures(degree), Ridge(alpha=1E-10, fit_intercept=False))
model.fit(x[:, None], y)
ridge = model.named_steps['ridge']
print(ridge.coef_)
coef = ridge.coef_

# Training error of the fit; arguments are passed as (y_true, y_pred).
poly_mse = mean_squared_error(y, model.predict(x[:, None]))
print('RMSE', np.sqrt(poly_mse))

# Dense grid for drawing the fitted curve; computed once and reused.
x_grid = np.arange(0.28, 0.85, 0.0001)
predictions = model.predict(x_grid.reshape(-1, 1))

# Build the legend text from the coefficients, highest power first.
# coef[p] multiplies x**p, so the label always matches `degree`.
terms = []
for power in range(degree, -1, -1):
    if power == 0:
        terms.append("%+.2f" % coef[power])
    elif power == 1:
        terms.append("%+.2fX" % coef[power])
    else:
        terms.append("%+.2f$X^{%d}$" % (coef[power], power))
fit_label = "Best Fit: " + " ".join(terms)

plt.plot(x, y, 'ro', label='Measurement Data')
plt.plot(x_grid, predictions, label=fit_label)
plt.title('K vs Barium Proportion (X) at 10kHz')
plt.xlabel('Barium Proportion (X)')
plt.ylabel('K')
plt.legend()  # without this the labels set above are never shown
plt.show()

Solution

  • Here is a graphical fitter using your data and a simple three-parameter logistic-type equation. The fit seems fairly good to me.

    plot

    import numpy, scipy, matplotlib
    import matplotlib.pyplot as plt
    from scipy.optimize import curve_fit
    import warnings
    
    # sample points to fit: x values and the corresponding measured y values
    xData = numpy.array([
        0.3, 0.4, 0.5, 0.6, 0.65, 0.67, 0.8,
    ])
    yData = numpy.array([
        -936.0, -892.0, -178.33, -50.7, -65.7, -70.44, -9.0,
    ])
    
    
    def func(x, a, b, c):
        """Three-parameter logistic-type model: a / (1 + (x / b) ** c).

        This is the "Logistic B" equation from zunzun.com. `x` may be a
        scalar or a numpy array; `a`, `b` and `c` are the fit parameters.
        """
        ratio = x / b
        denominator = 1.0 + numpy.power(ratio, c)
        return a / denominator
    
    
    # starting guess for (a, b, c); these are the same as the scipy defaults
    initialParameters = numpy.array([1.0, 1.0, 1.0])

    # Curve fit the test data. Warnings caused by the crude initial parameter
    # estimates are silenced only for the duration of the fit, so that
    # warnings raised anywhere else in the program remain visible.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        fittedParameters, pcov = curve_fit(func, xData, yData, initialParameters)

    # evaluate the fitted model at the measured x values
    modelPredictions = func(xData, *fittedParameters)

    # residuals: prediction minus observation
    absError = modelPredictions - yData

    SE = numpy.square(absError) # squared errors
    MSE = numpy.mean(SE) # mean squared errors
    RMSE = numpy.sqrt(MSE) # Root Mean Squared Error, RMSE
    # Variance-based R-squared: this matches 1 - SS_res/SS_tot only when the
    # residuals have zero mean.
    Rsquared = 1.0 - (numpy.var(absError) / numpy.var(yData))

    print('Parameters:', fittedParameters)
    print('RMSE:', RMSE)
    print('R-squared:', Rsquared)

    print()
    
    
    ##########################################################
    # graphics output section
    def ModelAndScatterPlot(graphWidth, graphHeight):
        """Show the raw data as markers with the fitted curve drawn over them.

        graphWidth and graphHeight are divided by 100 to obtain the figure
        size in inches, and the figure is created at dpi=100.
        """
        sizeInInches = (graphWidth/100.0, graphHeight/100.0)
        figure = plt.figure(figsize=sizeInInches, dpi=100)
        ax = figure.add_subplot(111)

        # measured points, drawn with diamond markers
        ax.plot(xData, yData,  'D')

        # fitted equation, sampled across the span of the measured x values
        lowX, highX = min(xData), max(xData)
        curveX = numpy.linspace(lowX, highX)
        curveY = func(curveX, *fittedParameters)
        ax.plot(curveX, curveY)

        # axis captions
        ax.set_xlabel('X Data')
        ax.set_ylabel('Y Data')

        plt.show()
        # release all pyplot figures once the window has been shown
        plt.close('all')
    
    # requested plot dimensions, passed straight to the plotting routine
    graphWidth, graphHeight = 800, 600
    ModelAndScatterPlot(graphWidth, graphHeight)