Search code examples
pythonloopsoptimizationmodel-fitting

Python fitting : optimize loop


I am trying to fit multiple gaussians to a given data set. This part of the program is using about 3 GB of memory by the time it reaches the 500th model, and I need to fit a total of ~2000 models. Here is a simplified version of my program with randomly generated data; it will not produce a good fit, but it illustrates the time issue:

import sys
# Raise the interpreter's recursion limit to 5000.
# NOTE(review): presumably needed because the model expression built in the
# fitting loop below grows by one term per gaussian — confirm.
sys.setrecursionlimit(5000)
import matplotlib.pyplot as plt
import numpy as np
import random
from random import uniform
# Synthetic stand-in data: 1000 random x positions and 1000 random y values.
x = [random.uniform(2200.0, 3100.0) for _ in range(1000)]
y = [random.uniform(1.0, 1000.0) for _ in range(1000)]

import sherpa.ui as ui
import numpy as np
# Load the x/y values into sherpa; 1 is the id of the first data set.
ui.load_arrays(1,x,y)
d1=ui.get_data()
# Define an error on y just for plotting purposes; it is not required for the fit.
d1.staterror=0.002*d1.y
ui.plot_data()
# Use the least-squares statistic for the fit.
ui.set_stat("leastsq")
# Fit a power law; pow1 is the short name given to the component.
ui.set_model(ui.powlaw1d.pow1)
# ui.show_all() will show you all the parameters for the model
# Set the power law's `ref` parameter before fitting.
pow1.ref=2500
ui.fit()
# Templates to fit: random positions (x2) and starting amplitudes (y2).
x2 = [random.uniform(2200.0, 3100.0) for _ in range(1000)]
y2 = [random.uniform(1.0, 1000.0) for _ in range(1000)]

# Initial model expression for fitting all the gaussians, and the operator
# used to join further components onto it.
model1, sign = "pow1", "+"
sigma = 45.0
g_pos = x2
# The fitted amplitude values are stored here.
g_ampl = []



# Freeze the power law so only the gaussians added below are fitted.
ui.freeze(model1)
# NOTE(review): the range starts at 1, so x2[0]/y2[0] are never used and
# g_ampl[k] corresponds to g_pos[k + 1] — confirm this offset is intended.
for n in range(1, 1000):  # this excludes the upper limit
    name = "g{}".format(n)
    # Create the component exactly once and keep the handle. Calling
    # create_model_component again with the same name is redundant and, per
    # the sherpa documentation, may over-write the component configured here.
    g = ui.create_model_component("gauss1d", name)
    ui.set_par(name + ".pos", x2[n], frozen=True)
    ui.set_par(name + ".ampl", y2[n])
    ui.set_par(name + ".fwhm", sigma, frozen=True)
    # Append the new gaussian to the running model expression.
    model1 = model1 + sign + name
    if y2[n] == 0.:
        # List zero amplitude for this model; no fit is run.
        g_ampl.append(0.)
    else:
        # Overwrite the source with the actual (extended) model.
        ui.set_source(model1)
        # The fit is run three times in a row.
        # NOTE(review): presumably to refine convergence — confirm it is needed.
        ui.fit()
        ui.fit()
        ui.fit()
        # Store the fitted amplitude of this gaussian only.
        g_ampl.append(g.ampl.val)
    # Freeze the whole model and go on to the next gaussian.
    ui.freeze(model1)

I am unable to figure out a way to optimize this part to make it efficient and less time consuming. Any ideas to help me make it run faster would be appreciated.


Solution

  • The problem with your code is that it seems to unnecessarily store all the data you want to fit. A better solution is to store only the result of each fit. I don't know sherpa very well, so here is a solution using scipy.optimize:

    import numpy as np
    import matplotlib.pyplot as plt
    from scipy.optimize import curve_fit
    import glob, os
    # Switch matplotlib to interactive mode.
    plt.ion()
    
    def gaussian_func(x, a, x0, sigma, c):
        """Evaluate a gaussian with a constant added.

        Computes ``a * exp(-(x - x0)**2 / (2 * sigma**2)) + c``.
        """
        shifted = x - x0
        exponent = -shifted**2 / (2 * sigma**2)
        return a * np.exp(exponent) + c
    
    # This creates two files with random data.
    # Don't use this in the actual program.
    xdata = np.linspace(0, 4, 50)
    for fname, centre in (('example.dat', 1.3), ('example2.dat', 2.3)):
        # Gaussian centred at `centre` plus normally distributed noise.
        noise = 0.2 * np.random.normal(size=len(xdata))
        ydata = gaussian_func(xdata, 2.5, centre, 0.5, 0) + noise
        np.savetxt(fname, np.array([xdata, ydata]).T)
    
    # Create your list of files with the data.
    # This example just loads all files with the .dat extension. The list is
    # sorted because glob returns matches in arbitrary order, which would
    # otherwise make the order of `results` vary between runs.
    filelist = sorted(glob.glob("*.dat"))
    print(filelist)

    # One entry per file; only the outcome of the fit is kept, not the data.
    results = []

    for file in filelist:
        data = np.loadtxt(file)
        xdata = data[:,0]
        ydata = data[:,1]
        # If the error of the points is included in your file,
        # change the following line to sigma = data[:,2]
        sigma = 0*ydata+0.2
        # Starting values for (a, x0, sigma, c) of gaussian_func.
        initial_guess = [2,1,1,0]
        popt, pcov = curve_fit(gaussian_func, xdata, ydata,p0=initial_guess,sigma=sigma)
        results.append({"filename":file,"parameters":popt,"covariance matrix":pcov})

        # This plots the result.
        # Should be commented out for the large dataset.
        plt.figure(1)
        plt.clf()
        plt.errorbar(xdata,ydata,sigma,fmt='ko')
        # Draw the fitted curve over the x-range of the loaded data rather
        # than a hard-coded interval, so it also fits files with other ranges.
        xplot = np.linspace(xdata.min(), xdata.max(), 100)
        plt.plot(xplot,gaussian_func(xplot,*popt),'r',linewidth=2)
        plt.draw()