Search code examples
Tags: matplotlib, transparency

Matplotlib Savefig Function Draws Axes Over Self if Transparent is True


I have a program that crunches data and displays the results in a multi-axis figure. I have many different sets of figures, which I'm trying to assemble into a report. To save memory, I create a single figure instance and clear its axes at the end of each loop iteration. Below is an example:

import matplotlib.pyplot as plt
import numpy as np

def the_figure():
    """Create the shared figure and attach it, with its axes, to this function.

    Storing the figure as an attribute of this function keeps one persistent
    figure reachable from anywhere that can see ``the_figure``. After a call:

    * ``the_figure.fig`` is a newly created pyplot figure;
    * ``the_figure.axes`` maps each panel name to its axes, laid out on a
      6x2 grid (row 0 is left empty; ``table`` spans rows 3-5 and both
      columns).
    """
    # One entry per panel, in creation order:
    # (name, grid location, extra subplot2grid keyword arguments).
    panels = (
        ('t_ax', (1, 0), {}),
        ('t_fit_ax', (1, 1), {}),
        ('o_ax', (2, 0), {}),
        ('o_fit_ax', (2, 1), {}),
        ('table', (3, 0), {'rowspan': 3, 'colspan': 2}),
    )

    # The new figure becomes pyplot's current figure, which is where the
    # subplot2grid calls below place their axes.
    the_figure.fig = plt.figure()
    the_figure.axes = {
        name: plt.subplot2grid((6, 2), loc, **extra)
        for name, loc, extra in panels
    }

#A function which makes figures using the single figure function       
def Disp(i=5):
    """Draw ``i`` frames into the shared figure, saving each frame twice.

    For every ``n`` in ``range(i)``, each axes of the shared figure gets the
    curve ``sin(n * x)`` for ``x`` sampled on ``[-pi/2, pi/2]``. The figure is
    then saved as ``test_folder/bad<n>`` (``transparent=True``) and
    ``test_folder/good<n>`` (``transparent=False``). Finally, every axes is
    cleared with ``cla()`` for reuse.

    Parameters
    ----------
    i : int, optional
        Number of frames to produce (default 5).
    """
    import os  # local import keeps this snippet self-contained

    # Build the shared figure lazily: the attribute only exists once
    # the_figure() has run. Catch just that case; a bare except would also
    # swallow things like KeyboardInterrupt.
    try:
        the_figure.fig
    except AttributeError:
        the_figure()

    axes = the_figure.axes
    xs = np.linspace(-np.pi / 2, np.pi / 2)

    for n in range(i):
        for ax in axes.values():
            ax.plot(xs, np.sin(xs * n))

        # os.path.join yields 'test_folder\\bad<n>' on Windows (as before) and
        # a real sub-directory path elsewhere, instead of a file whose name
        # literally contains a backslash.
        the_figure.fig.savefig(os.path.join('test_folder', 'bad' + str(n)),
                               transparent=True)
        the_figure.fig.savefig(os.path.join('test_folder', 'good' + str(n)),
                               transparent=False)

        # NOTE(review): despite the cla() calls below, images saved with
        # transparent=True have been reported to also show curves from the
        # previous iteration. The suggested workaround is to clear the whole
        # figure (fig.clear()) and recreate the axes on every iteration.
        for ax in axes.values():
            ax.cla()

When it finishes, the figures saved with `transparent=True` show the curves from their own loop iteration overlaid on the curves from the previous iteration. I have no idea what's going on.

With Transparency

Without Transparency


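Solution

Instead of calling `cla()` on each axes, clear the whole figure and recreate its axes at the start of every iteration: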

  • import matplotlib.pyplot as plt
    import numpy as np
    # A single module-level figure (number 1), created once and reused by
    # every Disp() call; as a global it persists just as a function
    # attribute would.
    fig = plt.figure(num=1)
    def init_axes(fig):
        """Clear ``fig`` and populate it with a fresh set of axes.

        The axes are laid out on a 6x2 grid. Row 0 is left empty, rows 1-2
        each hold a pair of plots, and ``table`` spans rows 3-5 across both
        columns.

        Parameters
        ----------
        fig : matplotlib.figure.Figure
            Figure to wipe and re-populate. It is activated by its pyplot
            figure number, so it must be a pyplot-managed figure (e.g. one
            created with ``plt.figure``).

        Returns
        -------
        dict
            Maps ``'t_ax'``, ``'t_fit_ax'``, ``'o_ax'``, ``'o_fit_ax'`` and
            ``'table'`` to the newly created axes.
        """
        fig.clear()
        # plt.subplot2grid always adds to pyplot's *current* figure. Make
        # `fig` current first; otherwise the axes could be created on
        # whichever figure happened to be active, not the one just cleared.
        plt.figure(fig.number)
        return dict(
            t_ax=plt.subplot2grid((6, 2), (1, 0)),
            t_fit_ax=plt.subplot2grid((6, 2), (1, 1)),
            o_ax=plt.subplot2grid((6, 2), (2, 0)),
            o_fit_ax=plt.subplot2grid((6, 2), (2, 1)),
            table=plt.subplot2grid((6, 2), (3, 0), rowspan=3, colspan=2),
        )
    #A function which makes figures using the single figure       
    def Disp(i=5):
        """Draw ``i`` frames into the module-level ``fig``, saving each twice.

        For every ``n`` in ``range(i)``:

        * the figure is cleared and its axes rebuilt via ``init_axes(fig)``;
        * each axes gets ``sin(n * x)`` for ``x`` sampled on ``[-pi/2, pi/2]``;
        * the figure is saved as ``bad<n>`` (``transparent=True``) and
          ``good<n>`` (``transparent=False``), relative to the current
          directory.

        Parameters
        ----------
        i : int, optional
            Number of frames to produce (default 5).
        """
        xs = np.linspace(-np.pi / 2, np.pi / 2)

        for n in range(i):
            # Rebuild the axes from scratch each iteration rather than
            # reusing cleared ones, so nothing from the previous frame can
            # leak into this one.
            axes = init_axes(fig)
            for ax in axes.values():
                ax.plot(xs, np.sin(xs * n))

            fig.savefig('bad' + str(n), transparent=True)
            fig.savefig('good' + str(n), transparent=False)