python matplotlib seaborn colorbar

How to avoid multiple colorbars when looping seaborn heatmap


I made 3 plots in a for loop with the following code, but the colorbars accumulate: the first figure I save has 1 colorbar, the 2nd has 2 and the 3rd has 3. It seems each previous plot adds a colorbar to the current one. How can I avoid this?

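import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# files: list of tab-separated .txt paths
for f in files:
    roi_signals = pd.read_csv(f, sep='\t')

    ax = sns.heatmap(roi_signals)  # sns.heatmap returns an Axes, not a Figure
    fig_name = f.replace('.txt', '.png')
    plt.savefig(fig_name)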
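A self-contained way to reproduce the symptom, assuming random DataFrames stand in for the tab-separated files (`frames` and `counts` are hypothetical names used only for illustration). Each `sns.heatmap` call without `ax` draws into the currently active Axes of the same figure and adds one more colorbar Axes, so the figure's Axes count grows by one per iteration.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Hypothetical stand-ins for the DataFrames read from `files`
rng = np.random.default_rng(0)
frames = [pd.DataFrame(rng.random((5, 4))) for _ in range(3)]

counts = []
for df in frames:
    sns.heatmap(df)                     # no ax given: draws into plt.gca() and adds a colorbar Axes
    counts.append(len(plt.gcf().axes))  # heatmap Axes plus every colorbar added so far

print(counts)  # [2, 3, 4]: one extra colorbar per iteration
plt.close("all")
```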

Solution

  • By default, sns.heatmap plots onto the existing Axes and allocates space for a colorbar:

    This will draw the heatmap into the currently-active Axes if none is provided to the ax argument. Part of this Axes space will be taken and used to plot a colormap, unless cbar is False or a separate Axes is provided to cbar_ax.


    The simplest solution is to clear the figure each iteration with plt.clf:

    for f in files:
        roi_signals = pd.read_csv(f, sep='\t')
    
        sns.heatmap(roi_signals)
        plt.savefig(f.replace('.txt', '.png'))
        plt.clf() # clear figure before next iteration
    

    Or specify cbar_ax to overwrite the previous iteration's colorbar:

    fig = plt.figure()
    for i, f in enumerate(files):
        roi_signals = pd.read_csv(f, sep='\t')
    
        cbar_ax = fig.axes[-1] if i else None  # reuse the previous colorbar Axes (if it exists)
        sns.heatmap(roi_signals, cbar_ax=cbar_ax)
        plt.savefig(f.replace('.txt', '.png'))
    

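    Or just create a new figure per iteration. This is not recommended for many iterations, because every figure stays open in memory until it is closed (matplotlib warns once more than 20 figures are open):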

    for f in files:
        roi_signals = pd.read_csv(f, sep='\t')
    
        fig = plt.figure() # create new figure each iteration
        sns.heatmap(roi_signals)
        plt.savefig(f.replace('.txt', '.png'))
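Since the quoted documentation says `heatmap` only falls back to the currently active Axes when nothing is passed to `ax`, another option (a swapped-in alternative, not one of the three approaches above) is to create a fresh figure and Axes per file with `plt.subplots`, pass the Axes explicitly, and close the figure after saving so open figures do not pile up. A minimal sketch, assuming random DataFrames and a temporary directory stand in for the real files; `frames`, `out_dir` and `axes_per_figure` are hypothetical names.

```python
import os
import tempfile

import matplotlib
matplotlib.use("Agg")  # headless backend for saving files
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Hypothetical stand-ins for the tab-separated input files
rng = np.random.default_rng(0)
frames = [pd.DataFrame(rng.random((5, 4))) for _ in range(3)]
out_dir = tempfile.mkdtemp()

axes_per_figure = []
for i, df in enumerate(frames):
    fig, ax = plt.subplots()               # fresh figure and Axes for this file
    sns.heatmap(df, ax=ax)                 # draw explicitly into `ax`, not plt.gca()
    axes_per_figure.append(len(fig.axes))  # heatmap Axes plus its single colorbar
    fig.savefig(os.path.join(out_dir, f"roi_{i}.png"))
    plt.close(fig)                         # release the figure so open figures don't accumulate

print(axes_per_figure)    # [2, 2, 2]
print(plt.get_fignums())  # [] (no figures left open)
```

Closing each figure right after saving is what addresses the memory concern of creating one figure per iteration.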