python seaborn heatmap subplot colorbar

Plotting multiple seaborn heatmaps with individual color bar


Is it possible to plot multiple seaborn heatmaps in a single figure, with shared y tick labels and an individual color bar for each heatmap, as in the figure below?

[Figure: desired layout, several heatmaps side by side with shared y tick labels and individual color bars]

So far, I can only plot the heatmaps individually, using the following code:

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

#Figure 1

plt.figure()
sns.set()
comp = sns.heatmap(df, cmap="coolwarm", linewidths=.5, xticklabels=True, yticklabels=True, cbar_kws={"orientation": "horizontal", "label": "Pathway completeness", "pad": 0.004})
comp.set_xticklabels(comp.get_xticklabels(), rotation=-90)
comp.xaxis.tick_top() # x axis on top
comp.xaxis.set_label_position('top')
cbar = comp.collections[0].colorbar
cbar.set_ticks([0, 50, 100])
cbar.set_ticklabels(['0%', '50%', '100%'])          
figure = comp.get_figure()
figure.savefig("hetmap16.png", format='png', bbox_inches='tight')

#Figure 2 (figure 3 is the same, but with a different database)

plt.figure()
sns.set()
df = pd.DataFrame(heatMapFvaMinDictP)
fvaMax = sns.heatmap(df, cmap="rocket_r", linewidths=.5, xticklabels=True, cbar_kws={"orientation": "horizontal", "label": "Minimum average flux", "pad": 0.004})
fvaMax.set_xticklabels(fvaMax.get_xticklabels(), rotation=-90)
fvaMax.xaxis.tick_top() # x axis on top
fvaMax.xaxis.set_label_position('top')
fvaMax.tick_params(axis='y', labelleft=False)
figure = fvaMax.get_figure()
figure.savefig("fva1.png", format='png', bbox_inches='tight')

Solution

  • Seaborn builds upon matplotlib, which can be used to customize the plots further. plt.subplots(ncols=3, sharey=True, ...) creates three subplots with a shared y-axis. Passing ax=ax1 to sns.heatmap(..., ax=...) draws the heatmap on that subplot. Note that sns.heatmap returns that same ax.

    The following code shows an example. vmin and vmax are set explicitly for the first heatmap so that its colorbar spans the full range from 0 to 100 and the ticks at 0% and 100% are shown; by default, the colorbar only runs from the minimum to the maximum of the values in the data.

    import matplotlib.pyplot as plt
    import numpy as np
    import pandas as pd
    import seaborn as sns
    
    sns.set()
    fig, (ax1, ax2, ax3) = plt.subplots(ncols=3, sharey=True, figsize=(20, 8))
    
    N = 20
    labels = [''.join(np.random.choice(list('abcdefghi '), 40)) for _ in range(N)]
    df = pd.DataFrame({'column 1': np.random.uniform(0, 100, N), 'column 2': np.random.uniform(0, 100, N)},
                      index=labels)
    sns.heatmap(df, cmap="coolwarm", linewidths=.5, xticklabels=True, yticklabels=True, ax=ax1, vmin=0, vmax=100,
                cbar_kws={"orientation": "horizontal", "label": "Pathway completeness", "pad": 0.004})
    ax1.set_xticklabels(ax1.get_xticklabels(), rotation=-90)
    ax1.xaxis.tick_top()  # x axis on top
    ax1.xaxis.set_label_position('top')
    cbar = ax1.collections[0].colorbar
    cbar.set_ticks([0, 50, 100])
    cbar.set_ticklabels(['0%', '50%', '100%'])
    
    for ax in (ax2, ax3):
        max_value = 10 if ax == ax2 else 1000
        df = pd.DataFrame({'column 1': np.random.uniform(0, max_value, N), 'column 2': np.random.uniform(0, max_value, N)},
                          index=labels)
        sns.heatmap(df, cmap="rocket_r", linewidths=.5, xticklabels=True, ax=ax,
                    cbar_kws={"orientation": "horizontal", "pad": 0.004,
                              "label": "Minimum average flux"})
        ax.set_xticklabels(ax.get_xticklabels(), rotation=-90)
        ax.xaxis.tick_top()  # x axis on top
        ax.xaxis.set_label_position('top')
    
    plt.tight_layout()
    fig.savefig("subplots.png", format='png', bbox_inches='tight')
    plt.show()
    

    [Figure: example plot]
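The points about plt.subplots(..., sharey=True) and the return value of sns.heatmap can be checked with a small sketch. It assumes a made-up 5-row, 2-column DataFrame and uses the headless Agg backend so it runs without a display. It checks that sns.heatmap returns the Axes it was given, that both subplots report the same y limits, and that each heatmap adds its own colorbar Axes to the figure (two subplots plus two colorbars).

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Hypothetical demo data: 5 rows, 2 columns
rng = np.random.default_rng(0)
df = pd.DataFrame(rng.uniform(0, 100, size=(5, 2)),
                  index=[f"row {i}" for i in range(5)],
                  columns=["column 1", "column 2"])

fig, (ax1, ax2) = plt.subplots(ncols=2, sharey=True)
returned = sns.heatmap(df, ax=ax1)  # draw on the first subplot
sns.heatmap(df, ax=ax2)             # draw on the second subplot

# sns.heatmap hands back the Axes it drew on
same_axes = returned is ax1
# With sharey=True the subplots share their y limits
shared_limits = ax1.get_ylim() == ax2.get_ylim()
# Each heatmap created its own colorbar Axes in the figure
n_axes = len(fig.axes)
print(same_axes, shared_limits, n_axes)
```

Because each call gets its own colorbar, every subplot can use its own cmap, label and tick formatting while the row labels stay aligned.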
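The effect of vmin and vmax on the colorbar range can be seen by drawing the same data twice. In this sketch the data are hypothetical and never reach 0 or 100; the color limits of the mesh drawn by sns.heatmap are read back with get_clim(). Without vmin/vmax the limits follow the data minimum and maximum; with vmin=0, vmax=100 they are pinned to that range, so colorbar ticks at 0 and 100 fall inside it.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# Hypothetical values that never reach 0 or 100
df = pd.DataFrame({"column 1": [12.0, 35.0, 60.0],
                   "column 2": [20.0, 45.0, 80.0]})

fig, (ax_auto, ax_fixed) = plt.subplots(ncols=2)
sns.heatmap(df, ax=ax_auto)                     # color range inferred from the data
sns.heatmap(df, ax=ax_fixed, vmin=0, vmax=100)  # color range pinned to 0..100

# The mesh drawn by sns.heatmap is the first collection on each Axes
auto_range = ax_auto.collections[0].get_clim()
fixed_range = ax_fixed.collections[0].get_clim()
print(auto_range, fixed_range)
```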