Tags: python, matplotlib, plot, seaborn, colorbar

How to prevent colorbar from moving up/down as heatmap height changes? Matplotlib/seaborn


I am generating a heatmap dynamically, and the number of categories on the y and x axes may differ each time. How can I position the colorbar next to the heatmap so that it is always anchored at the very top (i.e., level with the first row of the heatmap), regardless of the figure height?

Here's what's happening: as the figure height changes, the colorbar ends up at a different vertical position relative to the top row of the heatmap.

So far I have managed to keep the colorbar height and width roughly constant, whatever the figure size, by creating its axes with add_axes. However, I am struggling to set its vertical (y) position dynamically. Minimal example below:

import seaborn as sns
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable, axes_size

data = np.random.rand(5,5)

# Dynamic figure parameters (all in inches).
topmargin = 0.1
bottommargin = 0.1
# Height of one heatmap cell.
cat_height = 0.4
# Number of rows (y-axis categories).
n = data.shape[0]

leftmargin = 0.1
rightmargin = 0.1
# Width of one heatmap cell.
cat_width = 0.5
# Number of columns (x-axis categories).
m = data.shape[1]

# Dynamic figure height.
figheight = topmargin + bottommargin + (n+1)*cat_height
# Dynamic figure width.
figwidth = leftmargin + rightmargin + (m+1)*cat_width

fig, ax = plt.subplots(figsize=(figwidth, figheight))

# [left, bottom, width, height] as fractions of the figure (0-1).
cbar_ax = fig.add_axes([0.93, 0.33, 0.13/m, 2.75/n])

# Plot the heatmap.
ax = sns.heatmap(data, ax=ax, cmap='coolwarm', cbar_ax=cbar_ax, cbar=True)

plt.show()

Basically, the colorbar moves up and down when the figure height changes, but I would like it anchored at the top of the heatmap every time.
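To see why the colorbar drifts, note that every value passed to add_axes is a fraction of the figure. The colorbar's bottom is fixed at 0.33 and its height is 2.75/n, so its top edge lands at 0.33 + 2.75/n, while the heatmap axes keep their top edge at the default subplot top of 0.88. For n = 5 the two coincide (0.33 + 0.55 = 0.88); for n = 10 the colorbar top drops to 0.605; for n = 4 it rises to about 1.02, above the figure. The short sketch below only evaluates that arithmetic. It assumes default rcParams, and `colorbar_top` is a hypothetical helper, not part of the question's code.

```python
import matplotlib.pyplot as plt

# Top edge of the heatmap axes in figure coordinates (0.88 with default rcParams).
HEATMAP_TOP = plt.rcParams['figure.subplot.top']


def colorbar_top(n, bottom=0.33, height_numerator=2.75):
    """Hypothetical helper: top edge (figure fraction) of the colorbar axes created by
    fig.add_axes([0.93, bottom, 0.13/m, height_numerator/n]) in the question's code."""
    return bottom + height_numerator / n


for n in (4, 5, 10):
    print(f"n={n:2d}: colorbar top = {colorbar_top(n):.3f}, heatmap top = {HEATMAP_TOP:.3f}")
```

Only the bottom coordinate needs to change with n; the answer below does exactly that.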


Solution

  • You could simply calculate the bottom coordinate from the height of your colorbar and the top of the heatmap axes:

    cbar_ax = fig.add_axes([0.93, 0.88-2.75/n, 0.13/m, 2.75/n])
    

    0.88 is the top edge of the subplot with the default margins (see plt.rcParams['figure.subplot.top']).

    However, for this kind of thing I would use a GridSpec to define a grid of axes with configurable size ratios (adjust height_ratios to suit your needs):

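    import matplotlib.gridspec as gridspec
    gs = gridspec.GridSpec(2, 2, height_ratios=[3, n-3], width_ratios=[20, 1])
    fig = plt.figure(figsize=(figwidth, figheight))
    ax = fig.add_subplot(gs[:, 0])       # heatmap spans both rows of the left column
    cbar_ax = fig.add_subplot(gs[0, 1])  # colorbar occupies only the top-right cell
    sns.heatmap(data, ax=ax, cmap='coolwarm', cbar_ax=cbar_ax)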
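A self-contained sketch of the first approach (computing the bottom coordinate), under a few assumptions. Instead of hard-coding 0.88, it reads the heatmap axes' actual top edge with ax.get_position().y1, so it also follows non-default margins. The helper names `add_top_anchored_cbar_ax` and `plot_heatmap` are hypothetical, the four separate margins are collapsed into one `margin` value for brevity, and the Agg backend is selected only so the sketch runs without a display. The 2.75/n and 0.13/m colorbar sizes are the question's own values.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless; remove to display plots
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns


def add_top_anchored_cbar_ax(fig, ax, width, height, left=0.93):
    """Hypothetical helper: add a colorbar axes whose top edge lines up with the top of `ax`.

    All values are figure fractions, as in fig.add_axes([left, bottom, width, height]).
    """
    top = ax.get_position().y1  # top edge of the heatmap axes (0.88 with default margins)
    return fig.add_axes([left, top - height, width, height])


def plot_heatmap(n, m, cat_height=0.4, cat_width=0.5, margin=0.1):
    """Hypothetical wrapper around the question's code for an n x m random heatmap."""
    data = np.random.rand(n, m)
    figheight = 2 * margin + (n + 1) * cat_height
    figwidth = 2 * margin + (m + 1) * cat_width
    fig, ax = plt.subplots(figsize=(figwidth, figheight))
    # Colorbar size taken from the question: 0.13/m wide, 2.75/n tall (figure fractions).
    cbar_ax = add_top_anchored_cbar_ax(fig, ax, width=0.13 / m, height=2.75 / n)
    sns.heatmap(data, ax=ax, cmap='coolwarm', cbar_ax=cbar_ax)
    return fig, ax, cbar_ax


for n in (4, 8, 15):
    fig, ax, cbar_ax = plot_heatmap(n, 5)
    # Both tops should print the same value for every n.
    print(n, round(ax.get_position().y1, 3), round(cbar_ax.get_position().y1, 3))
    plt.close(fig)
```

With very few rows (n <= 3 here) the height 2.75/n exceeds the 0.88 available below the top edge, so the colorbar would extend past the bottom of the figure; capping the height would be needed in that case.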
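A fuller, runnable version of the GridSpec idea, as a sketch. It uses fig.add_gridspec, which creates a GridSpec tied to the figure, and it adds a guard that is an assumption rather than part of the answer: height_ratios=[3, n-3] only makes sense for n > 3, so the sketch shrinks the colorbar's share to at most n - 1 rows and falls back to a single-row grid when n == 1. The names `plot_heatmap_gridspec` and `cbar_rows` are hypothetical.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless; remove to display plots
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns


def plot_heatmap_gridspec(data, cbar_rows=3, cat_height=0.4, cat_width=0.5, margin=0.1):
    """Hypothetical helper: heatmap in the left column, colorbar in the top-right cell."""
    n, m = data.shape
    figheight = 2 * margin + (n + 1) * cat_height
    figwidth = 2 * margin + (m + 1) * cat_width
    fig = plt.figure(figsize=(figwidth, figheight))
    k = min(cbar_rows, n - 1)  # keep both height ratios positive
    if k >= 1:
        # Two grid rows: the colorbar cell gets a k : (n - k) share of the height.
        gs = fig.add_gridspec(2, 2, height_ratios=[k, n - k], width_ratios=[20, 1])
    else:
        # n == 1: a single grid row, so the colorbar spans the full heatmap height.
        gs = fig.add_gridspec(1, 2, width_ratios=[20, 1])
    ax = fig.add_subplot(gs[:, 0])       # heatmap spans every row of the left column
    cbar_ax = fig.add_subplot(gs[0, 1])  # colorbar starts at the top of the grid
    sns.heatmap(data, ax=ax, cmap='coolwarm', cbar_ax=cbar_ax)
    return fig, ax, cbar_ax


for n in (1, 2, 5, 12):
    fig, ax, cbar_ax = plot_heatmap_gridspec(np.random.rand(n, 4))
    print(n, round(ax.get_position().y1, 3), round(cbar_ax.get_position().y1, 3))
    plt.close(fig)
```

Because both axes start in grid row 0, their top edges coincide by construction, whatever the figure height; height_ratios then only controls how far down the colorbar reaches.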