
How to share a color palette between multiple subplots?


I have the following figure:


The figure is produced by the following code snippet:

import matplotlib.pyplot as plt
import numpy as np

fig = plt.figure(constrained_layout=True)
grid = fig.add_gridspec(2, 2)

ax_samples_losses = fig.add_subplot(grid[0, 0:])
ax_samples_losses.set_title('Avg. loss per train sample (epoch 0 excluded)')
for sample_idx, sample_avg_train_loss_history in enumerate(samples_avg_train_loss_history):
    ax_samples_losses.plot(sample_avg_train_loss_history, label='Sample ' + str(sample_idx))
ax_samples_losses.set_xlabel('Epoch')
ax_samples_losses.set_ylabel('Sample avg. loss')
ax_samples_losses.set_xticks(range(1, epochs))
ax_samples_losses.tick_params(axis='x', rotation=90)
ax_samples_losses.yaxis.set_ticks(np.arange(0, np.max(samples_avg_train_loss_history), 0.25))
ax_samples_losses.tick_params(axis='both', which='major', labelsize=6)
ax_samples_losses.legend(bbox_to_anchor=(1, 1), prop={'size': 6})  # loc="upper left"
# fig.legend(...)

ax_patches_per_sample = fig.add_subplot(grid[1, 0])
#for sample_idx, sample_patches_count in enumerate(samples_train_patches_count):
#    ax_patches_per_sample.bar(sample_patches_count, label='Sample ' + str(sample_idx))
ax_patches_per_sample.bar(range(0, len(samples_train_patches_count)), samples_train_patches_count, align='center')
ax_patches_per_sample.set_title('Patches per sample')
ax_patches_per_sample.set_xlabel('Sample')
ax_patches_per_sample.set_ylabel('Patch count')
ax_patches_per_sample.set_xticks(range(0, len(samples_train_patches_count)))
ax_patches_per_sample.yaxis.set_ticks(np.arange(0, np.max(samples_train_patches_count), 20))
ax_patches_per_sample.tick_params(axis='both', which='major', labelsize=6)

where

  • samples_train_patches_count is a simple list with the number of patches per sampled image
  • samples_avg_train_loss_history is a list of lists of shape (samples, epochs): viewed as a matrix, each row holds one sample's average loss over time, with one column per epoch
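Because samples_avg_train_loss_history is a (samples, epochs) matrix, the loop over its rows can also be written as a single plot call on the transposed array: ax.plot draws one line per column of a 2-D array, so every sample still gets exactly one line, and the default color cycle hands out colors in sample order. That ordering is what later lets the bars reuse the same colors by index. A minimal sketch, with random data standing in for the real loss history (the 5 x 20 shape is an assumption):

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

# Made-up stand-in for samples_avg_train_loss_history: 5 samples x 20 epochs
rng = np.random.default_rng(0)
history = rng.uniform(0.0, 2.0, size=(5, 20))

fig, ax = plt.subplots()
# ax.plot draws one line per *column*, so transpose to get one line per sample (row)
lines = ax.plot(history.T)
for idx, line in enumerate(lines):
    line.set_label(f"Sample {idx}")

print(len(lines))  # 5
```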

I believe I need both:

  • shared legend
  • shared color palette

A shared legend can be built with get_legend_handles_labels(). However, I do not know how to share colors. Both subplots describe different properties of the same things, the samples. In short, I would like the bars in the "Patches per sample" subplot to use the same colors as the corresponding sample lines in "Avg. loss per train sample (epoch 0 excluded)".
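For the shared-legend half, a minimal sketch of the get_legend_handles_labels() approach mentioned above: collect the handles and labels once from the axes that owns the labelled lines and pass them to a single figure-level legend. The layout mirrors the question's 2 x 2 grid; the sample count and curves are made up for illustration.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

fig = plt.figure(constrained_layout=True)
grid = fig.add_gridspec(2, 2)
ax_top = fig.add_subplot(grid[0, :])     # spans the whole first row
ax_bottom = fig.add_subplot(grid[1, 0])  # bottom-left cell

n_samples = 4  # hypothetical number of samples
epochs = np.arange(10)
for idx in range(n_samples):
    ax_top.plot(epochs, np.exp(-epochs / (idx + 1)), label=f"Sample {idx}")

# Collect handles/labels once from the axes that owns the labelled lines ...
handles, labels = ax_top.get_legend_handles_labels()
# ... and draw one figure-level legend instead of one legend per axes
legend = fig.legend(handles, labels, loc="lower right", fontsize=6)
```

Placing the legend at the figure's lower right puts it roughly in the unused bottom-right grid cell.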


Solution

  • The first plot uses Matplotlib's default color cycle, which is the discrete tab10 colormap. We can create a cycler (an itertools.cycle iterator) over this colormap's colors and set the color of each bar one by one:

    import matplotlib.pyplot as plt
    import matplotlib.cm as cm
    from matplotlib.gridspec import GridSpec
    import numpy as np
    from itertools import cycle
    
    # create an iterator that loops continuously over the colors of a discrete colormap
    cycler = cycle(cm.tab10.colors)
    
    N = 10
    x = np.arange(N).astype(int)
    y = np.random.uniform(5, 15, N)
    
    f = plt.figure()
    gs = GridSpec(2, 4)
    ax0 = f.add_subplot(gs[0, :-1])
    ax1 = f.add_subplot(gs[1, :-1])
    ax2 = f.add_subplot(gs[:, -1])
    
    for i in x:
        ax0.plot(x, np.exp(-x / (i + 1)), label="Sample %s" % (i + 1))
    h, l = ax0.get_legend_handles_labels()
    
    ax1.bar(x, y)
    for p in ax1.patches:
        p.set_facecolor(next(cycler))
    
    ax2.axis(False)
    ax2.legend(h, l)
    plt.tight_layout()
    

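The per-patch loop works, but ax.bar also accepts a sequence of colors through its color argument, one per bar. Since the default property cycle is the tab10 palette in a stock configuration, a sketch that reads the colors straight from rcParams["axes.prop_cycle"] and passes the same list to both plot and bar (the patch counts and the number of samples are made up):

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import to_rgba

# Colors of the default property cycle (tab10 in a stock Matplotlib config)
palette = plt.rcParams["axes.prop_cycle"].by_key()["color"]

n = 6  # hypothetical number of samples; <= len(palette), so no color repeats
x = np.arange(n)
counts = np.random.default_rng(1).integers(20, 120, size=n)  # made-up patch counts

fig, (ax_loss, ax_bar) = plt.subplots(2, 1)
for i in range(n):
    ax_loss.plot(x, np.exp(-x / (i + 1)), color=palette[i], label=f"Sample {i}")
# One color per bar, in the same order as the lines
bars = ax_bar.bar(x, counts, color=palette[:n])

# Sanity check: each bar's face color equals its sample's line color
same = all(np.allclose(bar.get_facecolor(), to_rgba(line.get_color()))
           for bar, line in zip(bars, ax_loss.get_lines()))
print(same)  # True
```

Passing color= explicitly to both calls makes the pairing independent of the order in which artists happen to be created.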

    EDIT to accommodate a comment: tab10 has only 10 colors, so with more than 10 samples the cycle repeats colors. To avoid repetitions, sample one color per sample from a colormap. Matplotlib offers many colormaps; alternatively, you can create your own.

    import matplotlib.pyplot as plt
    import matplotlib.cm as cm
    from matplotlib.gridspec import GridSpec
    import numpy as np
    
    N = 50
    # sample N evenly spaced colors from a continuous colormap (one per sample)
    colors = cm.viridis(np.linspace(0, 1, N))
    
    x = np.arange(N).astype(int)
    y = np.random.uniform(5, 15, N)
    
    f = plt.figure()
    gs = GridSpec(2, 4)
    ax0 = f.add_subplot(gs[0, :-1])
    ax1 = f.add_subplot(gs[1, :-1])
    ax2 = f.add_subplot(gs[:, -1])
    
    ax1.bar(x, y)
    
    for i in x:
        c = colors[i]  # the i-th sampled color, shared by the line and the bar
        ax0.plot(x, np.exp(-x / (i + 1)), color=c, label="Sample %s" % (i + 1))
        ax1.patches[i].set_facecolor(c)
    h, l = ax0.get_legend_handles_labels()
    
    ax2.axis(False)
    ax2.legend(h, l)
    plt.tight_layout()
    

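A minimal sketch of both options from the EDIT, assuming the samples should be colored in order: a small helper (sample_colors, a hypothetical name) that draws n evenly spaced colors from any colormap, applied to viridis and to a hand-made ListedColormap. The three hex colors are arbitrary examples, and the resulting array is passed to bar through its color argument instead of recoloring patches afterwards.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import ListedColormap

def sample_colors(cmap, n):
    """Hypothetical helper: n RGBA rows evenly spaced along a colormap."""
    return cmap(np.linspace(0, 1, n))

n = 50
viridis_colors = sample_colors(matplotlib.colormaps["viridis"], n)  # shape (n, 4)

# A hand-made colormap; these three hex colors are arbitrary examples
my_cmap = ListedColormap(["#d7191c", "#fdae61", "#2b83ba"], name="my_palette")
my_colors = sample_colors(my_cmap, 3)  # usable exactly like viridis_colors for 3 samples

x = np.arange(n)
fig, (ax_loss, ax_bar) = plt.subplots(2, 1)
for i in range(n):
    ax_loss.plot(x, np.exp(-x / (i + 1)), color=viridis_colors[i])
# color= accepts one RGBA row per bar
bars = ax_bar.bar(x, np.random.default_rng(2).uniform(5, 15, n), color=viridis_colors)
```

A continuous colormap such as viridis has 256 entries, so 50 evenly spaced samples stay distinct; a listed colormap only has as many colors as you give it.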