python, matplotlib, colorbar

matplotlib change colorbar height within own axes


I'm currently trying to create a vertical stack of plots, the first two of which have colorbars. To lay this out nicely, I'm using GridSpec to define two columns, the second being much thinner and reserved for colorbars (or other things that sit outside the plots, such as legends).

import matplotlib.pyplot as plt
import matplotlib.gridspec as gs
fig    = plt.figure()
grids  = gs.GridSpec(5, 2, width_ratios=[1, 0.01])
ax1    = fig.add_subplot(grids[0, 0])
cax1   = fig.add_subplot(grids[0, 1])

The problem is that for the top two plots, the tick labels of the two colorbars overlap slightly, because there is no vertical space between my stacked plots.

I know there are ways to control the height of a colorbar, but they seem to rely on the colorbar creating its own axes by borrowing space from the parent axes. Is there any way to control how much space (specifically, how much height) the colorbar takes up when you pass the cax kwarg:

fig.colorbar(im1, cax=cax1, extend='max')

or does it always (immutably) take up the entire height of the axes it is given?
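For reference, here is a minimal sketch that reproduces the layout above. The random image data, `hspace=0` (my guess at how the rows end up touching) and the Agg backend are assumptions, not the original code. It shows that with `cax=`, the colorbar axes is exactly as tall as its GridSpec cell, so colorbars in adjacent rows share an edge and their end tick labels can collide.

```python
import matplotlib
matplotlib.use("Agg")  # off-screen backend, so the sketch runs without a display
import matplotlib.pyplot as plt
import matplotlib.gridspec as gs
import numpy as np

fig = plt.figure(figsize=(6, 8))
# Assumption: hspace=0, so the stacked rows (and their colorbars) touch.
grids = gs.GridSpec(5, 2, width_ratios=[1, 0.01], hspace=0)

rng = np.random.default_rng(0)  # placeholder data
axes, caxes = [], []
for row in range(2):  # only the top two rows get colorbars
    ax = fig.add_subplot(grids[row, 0])
    cax = fig.add_subplot(grids[row, 1])
    im = ax.imshow(rng.random((10, 40)), aspect="auto")
    fig.colorbar(im, cax=cax, extend="max")
    axes.append(ax)
    caxes.append(cax)

# Layout positions (figure coordinates): each cax is exactly as tall as its
# row, and the bottom of the first cax coincides with the top of the second.
for ax, cax in zip(axes, caxes):
    print(ax.get_position(original=True).height,
          cax.get_position(original=True).height)
print(caxes[0].get_position(original=True).y0,
      caxes[1].get_position(original=True).y1)
```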

Thanks!

EDIT: Here's an image of the issue I'm struggling with.


If I could make the second colorbar slightly shorter, or shift the upper one slightly up, it wouldn't be an issue. Unfortunately, since I've used GridSpec (which has been amazing otherwise), I'm constrained to the limits of its axes.
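If you'd rather stay entirely inside GridSpec, one alternative (a different technique from the answer below, and only a sketch) is to split each colorbar cell with a nested `GridSpecFromSubplotSpec` and put the colorbar in the middle slot only. `short_cax` is a hypothetical helper name; the 0.9 fraction, random data and `hspace=0` are assumptions.

```python
import matplotlib
matplotlib.use("Agg")  # off-screen backend, so the sketch runs without a display
import matplotlib.pyplot as plt
import matplotlib.gridspec as gs
import numpy as np

fig = plt.figure(figsize=(6, 8))
grids = gs.GridSpec(5, 2, width_ratios=[1, 0.01], hspace=0)  # hspace=0 assumed


def short_cax(fig, cell, frac=0.9):
    """Hypothetical helper: add an axes filling the middle `frac` of `cell`'s height."""
    pad = (1 - frac) / 2
    # 3x1 sub-grid inside the cell: empty pad, colorbar slot, empty pad.
    # hspace=0 so the height ratios alone decide the slot sizes.
    inner = gs.GridSpecFromSubplotSpec(3, 1, subplot_spec=cell, hspace=0,
                                       height_ratios=[pad, frac, pad])
    return fig.add_subplot(inner[1, 0])


rng = np.random.default_rng(0)  # placeholder data
ax1 = fig.add_subplot(grids[0, 0])
cax1 = short_cax(fig, grids[0, 1], frac=0.9)
im1 = ax1.imshow(rng.random((10, 40)), aspect="auto")
fig.colorbar(im1, cax=cax1, extend="max")
```

Because the shorter position is derived from the GridSpec rather than set by hand, it should survive anything that re-applies the GridSpec layout, such as `fig.subplots_adjust`.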


Solution

  • I don't think there is any way to ask colorbar not to fill the whole cax. However, it is fairly trivial to shrink cax itself, either before or (actually) after plotting the colorbar.

    I wrote this small function:

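    def shrink_cbar(ax, shrink=0.9):
        """Shrink ax vertically to `shrink` times its height, keeping it centered."""
        # Layout position in figure coordinates (original=True ignores any
        # adjustment made to the active position at draw time)
        b = ax.get_position(original=True)
        new_h = b.height*shrink
        pad = (b.height-new_h)/2.   # equal gap above and below
        new_y0 = b.y0 + pad
        new_y1 = b.y1 - pad
        b.y0 = new_y0
        b.y1 = new_y1
        ax.set_position(b)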
    

    which can be used like so:

    import matplotlib.pyplot as plt
    import matplotlib.gridspec as gs

    fig = plt.figure()
    grids  = gs.GridSpec(2, 2, width_ratios=[1, 0.01])
    ax1    = fig.add_subplot(grids[0, 0])
    cax1   = fig.add_subplot(grids[0, 1])

    ax2    = fig.add_subplot(grids[1, 0])
    cax2   = fig.add_subplot(grids[1, 1])
    shrink_cbar(cax2, 0.75)  # cax2 keeps its center but is 75% as tall
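A slightly more compact sketch of the same idea uses the Bbox methods `shrunk` and `anchored` (the pair colorbar's own `shrink` option relies on, as far as I know), and applies it after `fig.colorbar` to show that the order doesn't matter. `shrink_cbar_centered` is a hypothetical name; the data and `hspace=0` are assumptions.

```python
import matplotlib
matplotlib.use("Agg")  # off-screen backend, so the sketch runs without a display
import matplotlib.pyplot as plt
import matplotlib.gridspec as gs
import numpy as np


def shrink_cbar_centered(ax, shrink=0.9):
    """Hypothetical helper: shrink ax vertically by `shrink`, keeping it centered."""
    b = ax.get_position(original=True)  # layout position, figure coordinates
    # shrunk(1, shrink) keeps the width and scales the height (pinned at the
    # lower-left corner); anchored("C", b) re-centers it inside the old box.
    ax.set_position(b.shrunk(1.0, shrink).anchored("C", b))


fig = plt.figure()
grids = gs.GridSpec(2, 2, width_ratios=[1, 0.01], hspace=0)  # hspace=0 assumed
rng = np.random.default_rng(0)  # placeholder data

ax1 = fig.add_subplot(grids[0, 0])
cax1 = fig.add_subplot(grids[0, 1])
ax2 = fig.add_subplot(grids[1, 0])
cax2 = fig.add_subplot(grids[1, 1])

for ax, cax in ((ax1, cax1), (ax2, cax2)):
    im = ax.imshow(rng.random((10, 40)), aspect="auto")
    fig.colorbar(im, cax=cax, extend="max")

before = cax2.get_position(original=True)
shrink_cbar_centered(cax2, 0.75)  # shrinking after fig.colorbar works too
after = cax2.get_position(original=True)
```

One caveat: because this sets the position by hand, anything that later re-applies the GridSpec positions (for example `fig.subplots_adjust` or `fig.tight_layout`) will undo the shrink, so call the helper last.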
    
