Search code examples
Tags: python, events, keypress, interactive

Interactive polynomial fit with Python


I want to fit a set of data points with a polynomial (which I normally do with numpy.polyfit), but I want to let the user choose the degree of the polynomial interactively. Here's an example of what I'm trying to do:

import sys
import numpy as np
import matplotlib.pyplot as plt
from pylab import *
from sklearn.metrics import mean_squared_error
# Sample data to be fitted: sin(x) sampled every 0.2 on [1, 10).
x = np.arange(1, 10, 0.2)
y = np.sin(x)

# Draw the raw samples as orange markers with thin black edges.
fig, ax = plt.subplots()
ax.plot(x, y, 'o', color='orange', markeredgewidth=0.3, markeredgecolor='k')
ax.set(xlim=(0, 10), ylim=(-1.1, 1.1))
def press(event):
  """Fit a polynomial whose degree is the pressed key, and draw it.

  Called by matplotlib on every 'key_press_event'. The key is parsed as
  the polynomial degree, and the module-level data x, y are fitted with
  np.polyfit. The rms of the residuals is written above the axes, the fit
  is drawn in green, and the coefficients are saved to 'prova.txt'
  (overwritten on each call).

  NOTE(review): nothing drawn by a previous call is removed, so each new
  fit line and rms label is drawn on top of the earlier ones. This is the
  overplotting described in the question.
  """
  #fig.clf()  # clears the whole figure, so the data points disappear too
  fig.canvas.draw_idle()
  sys.stdout.flush()
  # int() raises for any key that is not a digit (e.g. 'q', 'shift').
  deg = int(event.key)
  coeffs = np.polyfit(x,y,deg)
  p = np.poly1d(coeffs)
  rms = sqrt(mean_squared_error(y, p(x)))
  # A new Text artist is added on every call; the previous one stays.
  fig.text(0.8,1.02, 'rms='+str(round(rms,4)), rotation=0, color='k',transform=ax.transAxes)
  with open('prova.txt', 'w') as filehandle:
    filehandle.write('#Coefficients for a n= '+str(deg)+' polynomial fit\n\n')
    # One coefficient per line, highest power first (np.polyfit order).
    for listitem in coeffs:
      filehandle.write('%s\n' % listitem)
  # Likewise, a new line is added each time. `ln` is a local that is
  # discarded on return, so the old line can never be removed later.
  ln = plt.plot(x,p(x),'-',color='green',linewidth=0.8,zorder=0)
  fig.canvas.draw()
cid = fig.canvas.mpl_connect('key_press_event', press)
plt.show()
fig.canvas.mpl_disconnect(cid)
# Read back the coefficients of the last fit saved by press(). The file only
# exists once a digit key has been pressed, so guard against it missing
# instead of crashing. (The previous, unused open() handle was never closed.)
# NOTE(review): a prova.txt left over from an earlier run would be read here
# as if it were this session's fit.
try:
  cfs = np.loadtxt('prova.txt', usecols=0, comments='#')
except FileNotFoundError:
  print('No fit was made: press a digit key to choose the polynomial degree.')
else:
  print(cfs)

This does fit the points, but every new fit is drawn on top of the previous ones. If I uncomment fig.clf(), the fit is updated, but the data points are erased as well.


Solution

  • At the start, create the global variables and initialize them to None:

      txt = None
      ln = None
    

    Then, inside press(), check whether the line and the text have already been created and, if so, remove() them before drawing the new ones:

      global txt
      global ln
    
      if txt:
        txt.remove()
      txt = fig.text(0.8,1.02, 'rms='+str(round(rms,4)), rotation=0, color='k',transform=ax.transAxes)
    
      if ln:
        ln[0].remove()
      ln = plt.plot(x,p(x),'-',color='green',linewidth=0.8,zorder=0)
    
      fig.canvas.draw()
    

    Alternatively, keep the existing line and text and only replace their data and string:

    global txt
    global ln
    
    if txt:
        # replace text
        txt.set_text('rms={}'.format(round(rms,4)))
    else:
        # create first time
        txt = fig.text(0.8, 1.02, 'rms={}'.format(round(rms,4)), rotation=0, color='k', transform=ax.transAxes)
    
    if ln:
        # replace data
        ln[0].set_data(x, p(x))
    else:              
        # create first time
        ln = plt.plot(x, p(x), '-', color='green', linewidth=0.8, zorder=0)
    
    fig.canvas.draw()
    

    Full code

    import sys
    import numpy as np
    import matplotlib.pyplot as plt
    from pylab import *
    from sklearn.metrics import mean_squared_error
    
    # --- functions ---
    
    def press(event):
        """Refit and redraw the polynomial when a digit key is pressed.

        The pressed digit is used as the polynomial degree for fitting the
        module-level data x, y. The previous fit line and rms label (kept in
        the module-level ``ln`` and ``txt``) are removed before the new ones
        are drawn, so fits no longer pile up on top of each other. The
        coefficients are also written to 'prova.txt', one per line, highest
        power first (the order np.polyfit returns them in).

        Any other key (modifiers, or matplotlib's own shortcuts such as 'q'
        or 's') is ignored, rather than raising from ``int(event.key)``.
        """
        global txt
        global ln

        # event.key is a string such as '3', 'q' or 'shift', or None.
        # Only a single digit selects a degree.
        if event.key not in tuple('0123456789'):
            return
        deg = int(event.key)

        coeffs = np.polyfit(x, y, deg)
        p = np.poly1d(coeffs)
        fitted = p(x)
        # Root-mean-square residual, i.e. sqrt(mean squared error).
        rms = np.sqrt(np.mean((y - fitted) ** 2))

        with open('prova.txt', 'w') as filehandle:
            filehandle.write('#Coefficients for a n= {} polynomial fit\n\n'.format(deg))
            filehandle.writelines('{}\n'.format(c) for c in coeffs)

        # Drop the artists from the previous keypress before drawing new ones.
        if txt is not None:
            txt.remove()
        txt = fig.text(0.8, 1.02, 'rms={}'.format(round(rms, 4)), rotation=0, color='k', transform=ax.transAxes)

        if ln is not None:
            ln[0].remove()
        # Plot on `ax` explicitly; plt.plot would target whatever axes happen
        # to be current.
        ln = ax.plot(x, fitted, '-', color='green', linewidth=0.8, zorder=0)

        # Schedule a single redraw; this is the recommended call from
        # inside event callbacks.
        fig.canvas.draw_idle()
    
    # --- main ---

    txt = None  # rms label (matplotlib Text) from the previous fit, if any
    ln = None   # list returned by ax.plot for the previous fit line, if any

    fig, ax = plt.subplots()
    x = np.arange(1, 10, 0.2)
    y = np.sin(x)
    ax.plot(x, y, 'o', color='orange', markeredgewidth=0.3, markeredgecolor='k')

    ax.set_xlim(0, 10)
    ax.set_ylim(-1.1, 1.1)
    cid = fig.canvas.mpl_connect('key_press_event', press)
    plt.show()
    fig.canvas.mpl_disconnect(cid)

    # prova.txt is only written once a digit key has been pressed. Without
    # this guard, closing the window straight away crashes here.
    # NOTE(review): a prova.txt left over from an earlier run would be read
    # as if it were this session's fit.
    try:
        cfs = np.loadtxt('prova.txt', usecols=0, comments='#')
    except FileNotFoundError:
        print('No fit was made: press a digit key to choose the polynomial degree.')
    else:
        print(cfs)