python matplotlib seaborn legend

Seaborn heatmap legend disturbs the alignment of the plots in subplots


I have 2 dataframes like this:

import numpy as np
import pandas as pd

df1 = pd.DataFrame(np.random.randint(0, 50, size=(5, 5)), columns=list('ABCDE'))
df2 = pd.DataFrame(np.random.randint(0, 50, size=(5, 5)), columns=list('FGHIJ'))

I want to create two heatmaps side by side, one for each dataframe. Both dataframes have the same data range, so I want to use a single colorbar as the legend for both.
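A small sketch of what "same data range" means in code, assuming the range is not known in advance: take the minimum and maximum over both frames and pass them as vmin/vmax to both heatmaps, so that one colorbar is valid for either plot. (With np.random.randint(0, 50, ...) the values lie in 0..49, so the hard-coded vmin=0, vmax=50 used in the question already covers them.)

```python
import numpy as np
import pandas as pd

# Two frames like the ones in the question (random integers in [0, 50))
df1 = pd.DataFrame(np.random.randint(0, 50, size=(5, 5)), columns=list('ABCDE'))
df2 = pd.DataFrame(np.random.randint(0, 50, size=(5, 5)), columns=list('FGHIJ'))

# One colour scale for both heatmaps: take the extremes over *both* frames,
# so a single colorbar describes either plot correctly.
vmin = min(df1.min().min(), df2.min().min())
vmax = max(df1.max().max(), df2.max().max())
print(vmin, vmax)
```

Passing the same vmin, vmax and cmap to both sns.heatmap calls is what makes a shared colorbar meaningful.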

For this shared colorbar, I want to shrink it to half its size and move its label to the right of the colorbar.
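If "to the right of the colorbar" means beside the bar rather than right-aligned under it, one option (an assumption about the intended layout, not something stated in the question) is to draw the label as plain text in the colorbar's own Axes coordinates, just past the bar's right edge. cb.ax is the colorbar's Axes; the 1.05 offset is an arbitrary choice, and the Agg backend is selected only so the sketch runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend, only so the sketch runs without a display
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import matplotlib.colors as mcolors

fig, ax = plt.subplots(figsize=[5, 5])
sm = cm.ScalarMappable(norm=mcolors.Normalize(vmin=0, vmax=50), cmap="viridis")
# shrink=0.5 makes the horizontal bar half as long as the space it is given
cb = fig.colorbar(sm, ax=ax, location="bottom", pad=0.1, shrink=0.5)

# In Axes coordinates x=1.0 is the bar's right edge; place the text a bit
# beyond it, vertically centred on the bar.
label = cb.ax.text(1.05, 0.5, "Scale", transform=cb.ax.transAxes,
                   ha="left", va="center")
fig.canvas.draw()
```

Unlike set_label(loc='right'), which keeps the label below the bar and aligns it with the bar's right end, this puts the text on the same line as the bar.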

import matplotlib.pyplot as plt
import seaborn as sns

f, axs = plt.subplots(1, 2, sharex=False, figsize=[5, 5], gridspec_kw={'height_ratios': [10], 'width_ratios': [1, 1]})

cmap = sns.color_palette("viridis", as_cmap=True)
# The colorbar is attached to the left heatmap only
plot1 = sns.heatmap(df1, square=True, vmax=50, vmin=0, cmap=cmap, cbar=True,
                    cbar_kws=dict(orientation='horizontal', location='bottom', label='Scale', pad=0.1, shrink=0.5),
                    ax=axs[0])
plot2 = sns.heatmap(df2, square=True, vmax=50, vmin=0, cmap=cmap, cbar=False, ax=axs[1])
plt.show()
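A quick way to see what goes wrong, as a sketch: draw the figure and compare the vertical extents of the two heatmap Axes. The likely cause is that seaborn passes the heatmap's own Axes to Figure.colorbar, so with cbar=True on axs[0] matplotlib takes the colorbar's room out of the left subplot only. Because square=True keeps each heatmap square inside its slot, the two heatmaps then end up at different heights. The Agg backend is used only so the sketch runs headless.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend, only so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

df1 = pd.DataFrame(np.random.randint(0, 50, size=(5, 5)), columns=list('ABCDE'))
df2 = pd.DataFrame(np.random.randint(0, 50, size=(5, 5)), columns=list('FGHIJ'))

f, axs = plt.subplots(1, 2, figsize=[5, 5])
cmap = sns.color_palette("viridis", as_cmap=True)
# Same heatmap calls as in the question: only the left heatmap gets a colorbar
sns.heatmap(df1, square=True, vmin=0, vmax=50, cmap=cmap, cbar=True,
            cbar_kws=dict(orientation='horizontal', location='bottom',
                          label='Scale', pad=0.1, shrink=0.5),
            ax=axs[0])
sns.heatmap(df2, square=True, vmin=0, vmax=50, cmap=cmap, cbar=False, ax=axs[1])

f.canvas.draw()  # render once so the square aspect has been applied
left, right = axs[0].get_position(), axs[1].get_position()
print(f"left heatmap:  y0={left.y0:.3f}, y1={left.y1:.3f}")
print(f"right heatmap: y0={right.y0:.3f}, y1={right.y1:.3f}")
```

The bottom edges (y0, in figure coordinates) differ, which is the misalignment visible in the screenshot.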

My output looks like this:

[Image: output of the code above]

As you can see, the two heatmaps are not aligned at the same height.

How can I align these heatmaps at the same height and move the colorbar label to the right of the colorbar?


Solution

  • You could place a colorbar directly on the figure with matplotlib.figure.Figure.colorbar() and attach it to both axes by passing both of them as the "ax" argument; matplotlib then takes the colorbar's space from both heatmaps equally, so they stay level. The colorbar label can be moved to the right by setting the "loc" argument of the matplotlib.colorbar.Colorbar.set_label() function to "right":

    import matplotlib.pyplot as plt
    import numpy as np
    import pandas as pd
    import seaborn as sns
    import matplotlib.cm as cm
    import matplotlib.colors as mcolors
    
    df1 = pd.DataFrame(np.random.randint(0, 50, size=(5, 5)), columns=list('ABCDE'))
    df2 = pd.DataFrame(np.random.randint(0, 50, size=(5, 5)), columns=list('FGHIJ'))
    
    f, axs = plt.subplots(1, 2, sharex=False, figsize=[5, 5], gridspec_kw={'height_ratios': [10], 'width_ratios': [1, 1]})
    cmap = sns.color_palette("viridis", as_cmap=True)  # common cmap for both plots and their colorbar
    # draw both heatmaps without their own colorbars
    plot1 = sns.heatmap(df1, square=True, vmax=50, vmin=0, cmap=cmap, cbar=False, ax=axs[0])
    plot2 = sns.heatmap(df2, square=True, vmax=50, vmin=0, cmap=cmap, cbar=False, ax=axs[1])
    # place a common colorbar with a scalar mappable that uses the heatmaps' norm and cmap;
    # ax=axs makes matplotlib take the colorbar's space from both axes equally
    sm = cm.ScalarMappable(norm=mcolors.Normalize(vmin=0, vmax=50), cmap=cmap)
    # shrink=0.5 halves the colorbar's length; the label is set below with set_label
    cbar_kws = dict(orientation='horizontal', location='bottom', pad=0.1, shrink=0.5, ticks=[0, 25, 50])
    cb = f.colorbar(sm, ax=axs, **cbar_kws)
    cb.set_label('Scale', loc='right')  # align the label with the right end of the bar
    plt.show()
    

    The result:

    [Image: resulting figure]
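To check that the shared colorbar keeps the heatmaps level, here is a sketch along the same lines as the solution: passing ax=axs makes matplotlib shrink both parent Axes by the same amount, so after drawing, their bottom edges and heights should match. The loop over zip((df1, df2), axs) is just a compact rewrite of the two sns.heatmap calls, and the Agg backend is used only so it runs headless.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend, only so the sketch runs without a display
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import matplotlib.colors as mcolors
import numpy as np
import pandas as pd
import seaborn as sns

df1 = pd.DataFrame(np.random.randint(0, 50, size=(5, 5)), columns=list('ABCDE'))
df2 = pd.DataFrame(np.random.randint(0, 50, size=(5, 5)), columns=list('FGHIJ'))

f, axs = plt.subplots(1, 2, figsize=[5, 5])
cmap = sns.color_palette("viridis", as_cmap=True)
for df, ax in zip((df1, df2), axs):
    sns.heatmap(df, square=True, vmin=0, vmax=50, cmap=cmap, cbar=False, ax=ax)

# One colorbar whose space is taken from both Axes (ax=axs)
sm = cm.ScalarMappable(norm=mcolors.Normalize(vmin=0, vmax=50), cmap=cmap)
cb = f.colorbar(sm, ax=axs, location='bottom', pad=0.1, shrink=0.5, ticks=[0, 25, 50])
cb.set_label('Scale', loc='right')

f.canvas.draw()  # render once so the square aspect has been applied
left, right = axs[0].get_position(), axs[1].get_position()
print(f"left heatmap:  y0={left.y0:.3f}, height={left.height:.3f}")
print(f"right heatmap: y0={right.y0:.3f}, height={right.height:.3f}")
```

Both heatmaps report the same y0 and height, in contrast to the question's version where only axs[0] loses space to the colorbar.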