
Fix a legend in an animation created by celluloid


I want to animate the process of finding the minimum of a function with different gradient descent optimization methods. For this I am using the matplotlib and celluloid packages. The problem is that I cannot fix the legend of the plot in the animation: on every loop iteration a new legend is added below the previous one, as you can see in the figure below. Is there any way to fix the legend and avoid this problem?

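import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
from celluloid import Camera

# x_mesh, y_mesh, z, minima_ and path1 ... path10 (each of shape (2, n_steps)) are computed earlier
fig, ax = plt.subplots(1, 1, figsize=(10, 10))
camera = Camera(fig)
for i in range(path1.shape[1]):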
  ax.contour(x_mesh, y_mesh, z, levels=np.logspace(0, 5, 35), norm=LogNorm(), cmap=plt.cm.jet)
  ax.plot(*minima_, 'r*', markersize=18)

  line, = ax.plot([], [], 'k', label='Simple SGD', lw=2)
  point, = ax.plot([], [], 'ko')
  line.set_data(*path1[::,:i])
  point.set_data(*path1[::,i-1:i])

  line, = ax.plot([], [], 'r', label='SGD with momentum', lw=2)
  point, = ax.plot([], [], 'ro')
  line.set_data(*path2[::,:i])
  point.set_data(*path2[::,i-1:i])

  line, = ax.plot([], [], 'g', label='SGD with Nesterov', lw=2)
  point, = ax.plot([], [], 'go')
  line.set_data(*path3[::,:i])
  point.set_data(*path3[::,i-1:i])

  line, = ax.plot([], [], 'b', label='SGD with Adagrad', lw=2)
  point, = ax.plot([], [], 'bo')
  line.set_data(*path4[::,:i])
  point.set_data(*path4[::,i-1:i])

  line, = ax.plot([], [], 'c', label='SGD with Adadelta', lw=2)
  point, = ax.plot([], [], 'co')
  line.set_data(*path5[::,:i])
  point.set_data(*path5[::,i-1:i]) 

  line, = ax.plot([], [], 'm', label='SGD with RMSprop', lw=2)
  point, = ax.plot([], [], 'mo')
  line.set_data(*path6[::,:i])
  point.set_data(*path6[::,i-1:i])

  line, = ax.plot([], [], 'y', label='SGD with Adam', lw=2)
  point, = ax.plot([], [], 'yo')
  line.set_data(*path7[::,:i])
  point.set_data(*path7[::,i-1:i])

  line, = ax.plot([], [], 'y', label='SGD with Adamax', lw=2)
  point, = ax.plot([], [], 'y*')
  line.set_data(*path8[::,:i])
  point.set_data(*path8[::,i-1:i])

  line, = ax.plot([], [], 'k', label='SGD with Nadam', lw=2)
  point, = ax.plot([], [], 'kp')
  line.set_data(*path9[::,:i])
  point.set_data(*path9[::,i-1:i])

  line, = ax.plot([], [], 'r', label='SGD with AMSGrad', lw=2)
  point, = ax.plot([], [], 'rD')
  line.set_data(*path10[::,:i])
  point.set_data(*path10[::,i-1:i])

  ax.legend(loc='upper left') 
  camera.snap()
animation = camera.animate()
animation.save('2D_animation_overlap.gif', writer='imagemagick')

[Figure: a frame of the animation, in which a new legend has been added below the previous one]


Solution

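  • The best practice here is to create the legend yourself from explicit handles instead of letting matplotlib generate it automatically. celluloid never clears the axes between frames, so the lines plotted in earlier iterations stay on the axes and every call to ax.legend() collects the labels of all of them; that is why the legend grows by a full set of entries on every frame. A custom legend avoids this, and in this case it could be built like this: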

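    import matplotlib.pyplot as plt
    from matplotlib.lines import Line2D

    # One entry per optimizer, in the same order and colours as the plot calls
    labels = ['Simple SGD', 'SGD with momentum', 'SGD with Nesterov',
              'SGD with Adagrad', 'SGD with Adadelta', 'SGD with RMSprop', 'SGD with Adam',
              'SGD with Adamax', 'SGD with Nadam', 'SGD with AMSGrad']
    colors = ['k', 'r', 'g', 'b', 'c', 'm', 'y', 'y', 'k', 'r']

    # Proxy artists: dummy lines that exist only to be drawn in the legend
    handles = [Line2D([0], [0], color=c, lw=2, label=l) for c, l in zip(colors, labels)]

    # Passing the handles explicitly means the legend ignores the lines
    # accumulated on the axes from earlier frames
    plt.legend(handles=handles, loc='upper left')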
    

    which will give you a legend like this:

    [Figure: the resulting legend, with one entry per optimizer]

    None of this needs to be inside the loop: you can create the legend before or after the loop and it will still work. It would also work inside the loop, but redrawing the legend on every frame is unnecessary.

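    Alternatively, instead of creating the legend manually, it would suffice to guard the automatic legend creation with an if statement so that it only runs on the first frame, when each label has been plotted exactly once, i.e.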

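        # ... inside the loop, after the ax.plot(...) calls for this frame
        if i == 0:
            # Only the first frame's lines exist yet, so each label appears once
            ax.legend(loc='upper left')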
    

    However, I would recommend creating the legend directly rather than relying on guarding the automatic legend generation.
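The accumulation can be reproduced with matplotlib alone. The sketch below assumes, as described above, that celluloid never clears the axes between frames, so it simply keeps plotting onto the same axes; the random-walk paths and the three-method subset are hypothetical stand-ins for path1 ... path10. It compares the number of entries in the automatic legend after each frame with a legend built from explicit Line2D handles, which here also carry the point markers so that methods sharing a colour (Adam and Adamax are both yellow) stay distinguishable.

```python
import matplotlib
matplotlib.use("Agg")  # render off-screen; no display needed
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.lines import Line2D

# Hypothetical stand-ins for path1 ... path10: arrays of shape (2, n_steps)
rng = np.random.default_rng(0)
n_steps = 5
methods = [  # (label, color, marker), a subset of the question's methods
    ("Simple SGD", "k", "o"),
    ("SGD with Adam", "y", "o"),
    ("SGD with Adamax", "y", "*"),
]
paths = [np.cumsum(rng.normal(size=(2, n_steps)), axis=1) for _ in methods]

fig, ax = plt.subplots()
auto_counts = []
for i in range(1, n_steps + 1):
    # Like celluloid, never clear the axes: each frame adds new artists on top
    for (label, color, marker), path in zip(methods, paths):
        ax.plot(*path[:, :i], color=color, label=label, lw=2)         # trail
        ax.plot(*path[:, i - 1:i], color=color, marker=marker, ls="")  # current point (unlabelled)
    # The automatic legend collects every labelled line plotted so far
    auto_counts.append(len(ax.legend(loc="upper left").get_texts()))

# A legend built from explicit proxy handles always has one entry per method
handles = [
    Line2D([0], [0], color=color, marker=marker, lw=2, label=label)
    for label, color, marker in methods
]
fixed_legend = ax.legend(handles=handles, loc="upper left")

print(auto_counts)                     # -> [3, 6, 9, 12, 15]
print(len(fixed_legend.get_texts()))   # -> 3
```

With the question's ten methods the automatic legend would gain ten entries per frame in the same way, which is the growing legend shown in the question's figure.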