I have the following data (a DataFrame) in Python:
product baskets scaling_factor
12345 475 95.5
12345 108 57.7
12345 2 1.4
12345 38 21.9
12345 320 88.8
and I want to run the following non-linear regression to estimate the parameters a, b and c.
The equation I want to fit:
scaling_factor = a - (b*np.exp(c*baskets))
In SAS we usually run the following model (it uses the Gauss-Newton method):
proc nlin data=scaling_factors;
parms a=100 b=100 c=-0.09;
model scaling_factor = a - (b * (exp(c*baskets)));
output out=scaling_equation_parms
parms=a b c;
Is there a similar way to estimate the parameters in Python using non-linear regression, and how can I plot the result in Python?
Agreeing with Chris Mueller, I'd also use scipy, but specifically scipy.optimize.curve_fit.
The code looks like:
###the top two lines are required on my linux machine
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.pyplot import cm
import numpy as np
from scipy.optimize import curve_fit #we could import more, but this is what we need
###defining your fitfunction
def func(x, a, b, c):
    return a - b * np.exp(c * x)
###OP's data
baskets = np.array([475, 108, 2, 38, 320])
scaling_factor = np.array([95.5, 57.7, 1.4, 21.9, 88.8])
###let us guess some start values
initialGuess=[100, 100,-.01]
guessedFactors=[func(x,*initialGuess ) for x in baskets]
###making the actual fit
popt,pcov = curve_fit(func, baskets, scaling_factor,initialGuess)
###one may want to inspect the fitted parameters and their covariance
print(popt)
print(pcov)
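###preparing data for showing the fit
###a dense grid over the observed basket range (50 points is an arbitrary choice)
basketCont = np.linspace(min(baskets), max(baskets), 50)
fittedData = [func(x, *popt) for x in basketCont]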
###preparing the figure
fig1 = plt.figure(1)
ax = fig1.add_subplot(1, 1, 1)
###the three sets of data to plot
ax.plot(baskets,scaling_factor,linestyle='',marker='o', color='r',label="data")
ax.plot(baskets,guessedFactors,linestyle='',marker='^', color='b',label="initial guess")
ax.plot(basketCont,fittedData,linestyle='-', color='#900000',label="fit with ({0:0.2g},{1:0.2g},{2:0.2g})".format(*popt))
ax.legend(loc=0, title="graphs", fontsize=12)
###putting the covariance matrix nicely
tab= [['{:.2g}'.format(j) for j in i] for i in pcov]
the_table = plt.table(cellText=tab,
colWidths = [0.2]*3,
loc='upper right', bbox=[0.483, 0.35, 0.5, 0.25] )
###putting the plot
plt.show()
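If you also want something closer to the parameter table that proc nlin prints, here is a minimal sketch of how one might get approximate standard errors and an R^2 from the same fit. It reuses the OP's data and the start values from above. The names perr, residuals, ss_res, ss_tot and r_squared are my own, not part of scipy, and taking the square root of the diagonal of pcov as a standard error is only the usual approximation. With five data points I would treat these numbers as rough indications.

```python
import numpy as np
from scipy.optimize import curve_fit

def func(x, a, b, c):
    # model from the question: scaling_factor = a - b*exp(c*baskets)
    return a - b * np.exp(c * x)

# OP's data
baskets = np.array([475, 108, 2, 38, 320], dtype=float)
scaling_factor = np.array([95.5, 57.7, 1.4, 21.9, 88.8])

# same start values as in the answer above
popt, pcov = curve_fit(func, baskets, scaling_factor, p0=[100, 100, -0.01])

# approximate one-sigma errors: square root of the covariance diagonal
perr = np.sqrt(np.diag(pcov))

# coefficient of determination computed from the residuals
residuals = scaling_factor - func(baskets, *popt)
ss_res = np.sum(residuals ** 2)
ss_tot = np.sum((scaling_factor - scaling_factor.mean()) ** 2)
r_squared = 1 - ss_res / ss_tot

for name, value, err in zip("abc", popt, perr):
    print("{} = {:.4g} +/- {:.2g}".format(name, value, err))
print("R^2 = {:.4f}".format(r_squared))
```

Note that curve_fit without bounds uses a Levenberg-Marquardt solver by default, which is related to but not the same as the Gauss-Newton method in SAS, so the estimates may differ slightly.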