Tags: python, matplotlib, colorbar

How to place multiple colorbars next to each other in the same axis automatically


I have a plot (example below) where I need to put multiple colorbars in the same axis.

[Image: example plot]

At the moment I have to create a new axis for every colorbar, with its position and size set by hand, like so:

    import matplotlib.pyplot as plt

    # cs_snow and cs_rain are the mappables already drawn on the map (not shown here)
    fig = plt.gcf()

    # left, bottom, width, height in figure coordinates, tuned by hand
    x_cbar_0, y_cbar_0, x_cbar_size, y_cbar_size     = 0.18, 0.05, 0.3, 0.02
    x_cbar2_0, y_cbar2_0, x_cbar2_size, y_cbar2_size = 0.55, 0.05, 0.3, 0.02

    ax_cbar = fig.add_axes([x_cbar_0, y_cbar_0, x_cbar_size, y_cbar_size])
    ax_cbar_2 = fig.add_axes([x_cbar2_0, y_cbar2_0, x_cbar2_size, y_cbar2_size])
    cbar_snow = fig.colorbar(cs_snow, cax=ax_cbar, orientation='horizontal',
                             label='Snow')
    cbar_rain = fig.colorbar(cs_rain, cax=ax_cbar_2, orientation='horizontal',
                             label='Rain')

This is not really portable: whenever something changes, for example the map projection, the plot is resized slightly and I have to retune these numbers by hand so that the colorbars don't overlap.

Isn't there a way to automatically create a new axis as wide as the figure (I believe mpl_toolkits.axes_grid1.axes_divider should do that) and then split it into an arbitrary number of sub-axes to hold the colorbars?


Solution

  • As @Mr. T said above, a gridspec is a pretty good way to do this. More complicated layouts may call for nested gridspecs, but a single one is enough here:

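    import matplotlib.pyplot as plt
    import numpy as np

    # constrained_layout sizes and spaces the axes, so no manual coordinates are needed
    fig = plt.figure(constrained_layout=True)

    # Row 0: the plot. Row 1: a thin strip for colorbars, with a narrow spacer column in the middle.
    gs = fig.add_gridspec(2, 3, height_ratios=[1, 0.05], width_ratios=[1, 0.2, 1])

    ax = fig.add_subplot(gs[0, :])

    # pc2 is drawn over pc1; both only serve as colorbar sources in this demo
    pc1 = ax.pcolormesh(np.random.randn(20, 20), cmap='viridis')
    pc2 = ax.pcolormesh(np.random.randn(20, 20), cmap='RdBu_r')

    # Left colorbar
    cax1 = fig.add_subplot(gs[1, 0])
    fig.colorbar(pc1, cax=cax1, orientation='horizontal')

    # Right colorbar; the middle column gs[1, 1] stays empty as a gap
    cax2 = fig.add_subplot(gs[1, -1])
    fig.colorbar(pc2, cax=cax2, orientation='horizontal', extend='max')

    plt.show()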
    

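For the "arbitrary number of sub-axes" part of the question, the same idea can probably be generalized with a nested gridspec. The following is only a sketch under a few assumptions: `add_colorbar_row` is a hypothetical helper name (not part of matplotlib), the `gap` ratio of 0.2 is simply carried over from the spacer column above, and the third label ("Hail") and the random data are made up for illustration.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np


def add_colorbar_row(fig, spec, mappables, labels, gap=0.2):
    """Hypothetical helper: split `spec` into one horizontal colorbar per mappable."""
    n = len(mappables)
    # Columns alternate colorbar, gap, colorbar, ... giving 2*n - 1 columns.
    widths = [1 if i % 2 == 0 else gap for i in range(2 * n - 1)]
    sub = spec.subgridspec(1, 2 * n - 1, width_ratios=widths)
    cbars = []
    for i, (mappable, label) in enumerate(zip(mappables, labels)):
        cax = fig.add_subplot(sub[0, 2 * i])  # even columns hold the colorbars
        cbars.append(
            fig.colorbar(mappable, cax=cax, orientation="horizontal", label=label)
        )
    return cbars


rng = np.random.default_rng(0)
fig = plt.figure(constrained_layout=True)

# Row 0: the plot. Row 1: a thin strip that the helper subdivides.
gs = fig.add_gridspec(2, 1, height_ratios=[1, 0.05])
ax = fig.add_subplot(gs[0])

# Illustrative data only; in the question these would be cs_snow, cs_rain, etc.
data = rng.standard_normal((20, 20))
meshes = [ax.pcolormesh(data, cmap=c) for c in ("viridis", "RdBu_r", "magma")]

cbars = add_colorbar_row(fig, gs[1], meshes, ["Snow", "Rain", "Hail"])
```

To view the figure interactively, drop the `matplotlib.use("Agg")` line and call `plt.show()` at the end. Because the colorbar strip is a gridspec row rather than a set of hand-placed axes, constrained layout should keep the colorbars from overlapping when the main plot is resized.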