Tags: python, matplotlib, seaborn, heatmap, subplot

Ensure consistent column widths when using heatmap with subplot2grid


I'm trying to format my subplots, but for some reason I can't figure out why their positions don't stay aligned with each other. Right now, they look like this:

[image: current output]

As you can see, I have two issues: 1. I don't know how to remove the axis labels (like "dates"), and 2. I need the subplots to share the same axis so that their columns stay aligned. My code so far:

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# pivot is the DataFrame shown below
fig = plt.figure(figsize=(25, 15))
ax1 = plt.subplot2grid((23,20), (0,0), colspan=19, rowspan=17)
ax2 = plt.subplot2grid((23,20), (19,0), colspan=19, rowspan=1)

sns.set(font_scale=0.95)

# main heatmap, then a one-row heatmap of the column totals underneath it
sns.heatmap(pivot, ax=ax1, annot=True, fmt=".0f", robust=True, linewidth=0.1, square=True, cmap="Blues")
sns.heatmap((pd.DataFrame(pivot.sum(axis=0))).transpose(), ax=ax2, annot=True, fmt=".0f", robust=True, linewidth=0.1, square=True, cmap="Blues", xticklabels=False, yticklabels=False)

plt.show()

My dataframe is this:

dates   2020Q1  2020Q2  2020Q3  2020Q4  2021Q1  2021Q2  2021Q3
inicio                                                        
2020Q1    56.0    45.0    15.0     7.0     4.0     4.0     3.0
2020Q2     NaN   418.0   277.0    86.0    46.0    33.0    28.0
2020Q3     NaN     NaN   619.0   398.0   167.0   122.0    93.0
2020Q4     NaN     NaN     NaN  1163.0   916.0   521.0   319.0
2021Q1     NaN     NaN     NaN     NaN   976.0   680.0   363.0
2021Q2     NaN     NaN     NaN     NaN     NaN   811.0   559.0
2021Q3     NaN     NaN     NaN     NaN     NaN     NaN  1879.0
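To make the question reproducible without the original data source, the printed frame can be rebuilt directly. This is a minimal sketch that assumes the row index is named `inicio` and the columns `dates`, exactly as in the printout; `totals` is a hypothetical name for the column sums that feed the bottom heatmap.

```python
import numpy as np
import pandas as pd

quarters = ["2020Q1", "2020Q2", "2020Q3", "2020Q4", "2021Q1", "2021Q2", "2021Q3"]
nan = np.nan

# values copied from the printout: rows are "inicio", columns are "dates"
rows = [
    [56.0, 45.0, 15.0, 7.0, 4.0, 4.0, 3.0],
    [nan, 418.0, 277.0, 86.0, 46.0, 33.0, 28.0],
    [nan, nan, 619.0, 398.0, 167.0, 122.0, 93.0],
    [nan, nan, nan, 1163.0, 916.0, 521.0, 319.0],
    [nan, nan, nan, nan, 976.0, 680.0, 363.0],
    [nan, nan, nan, nan, nan, 811.0, 559.0],
    [nan, nan, nan, nan, nan, nan, 1879.0],
]
pivot = pd.DataFrame(
    rows,
    index=pd.Index(quarters, name="inicio"),
    columns=pd.Index(quarters, name="dates"),
)

# column totals for the one-row heatmap; sum() skips the NaN cells
totals = pivot.sum(axis=0)
print(totals)
```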

Solution

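    • Changing square=True to square=False in seaborn.heatmap makes all the columns the same width. With square=True, seaborn fixes an equal aspect ratio on each axes, so matplotlib shrinks each axes until its cells are square; the tall 7x7 heatmap and the one-row totals heatmap therefore end up with different widths. With square=False, the cells stretch to fill each axes, and because both axes span the same grid columns, their columns line up.
    • Axis labels (like "dates") can be removed by setting them to an empty string: ax1.set(xlabel='', ylabel='')
    • Tested in Python 3.8.11, pandas 1.3.3, matplotlib 3.4.3, seaborn 0.11.2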
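The width effect of square can be checked by drawing the same two heatmaps with each setting and measuring the axes. This is a sketch under a few assumptions: the data is a placeholder 7x7 frame, colorbars are turned off (cbar=False) so they don't take space and muddy the comparison, `heatmap_widths` is a hypothetical helper, and the Agg backend is used only so it runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # off-screen backend so the measurement runs without a display
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns


def heatmap_widths(square):
    """Draw a 7x7 heatmap above a one-row heatmap of its column totals.

    Returns the widths of both axes as fractions of the figure width.
    """
    frame = pd.DataFrame(np.arange(1.0, 50.0).reshape(7, 7))  # placeholder data
    fig = plt.figure(figsize=(20, 10))
    ax1 = plt.subplot2grid((20, 10), (0, 0), colspan=10, rowspan=17, fig=fig)
    ax2 = plt.subplot2grid((20, 10), (19, 0), colspan=10, rowspan=1, fig=fig)
    # cbar=False keeps colorbars from taking space, isolating the effect of square
    sns.heatmap(frame, ax=ax1, square=square, cbar=False)
    sns.heatmap(frame.sum().to_frame().T, ax=ax2, square=square, cbar=False)
    fig.canvas.draw()  # the equal aspect set by square=True is applied at draw time
    widths = (ax1.get_position().width, ax2.get_position().width)
    plt.close(fig)
    return widths


print(heatmap_widths(square=True))   # the two widths differ
print(heatmap_widths(square=False))  # the two widths match
```

With square=True the 17-row axes keeps a much wider box than the 1-row axes, since each is shrunk to fit 7 square cells into its own height.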
    import matplotlib.pyplot as plt
    import pandas as pd
    import seaborn as sns
    
    # test dataframe
    data = {'dates': ['2020Q1', '2020Q1', '2020Q1', '2020Q1', '2020Q1', '2020Q1', '2020Q1', '2020Q2', '2020Q2', '2020Q2', '2020Q2', '2020Q2', '2020Q2', '2020Q3', '2020Q3', '2020Q3', '2020Q3', '2020Q3', '2020Q4', '2020Q4', '2020Q4', '2020Q4', '2021Q1', '2021Q1', '2021Q1', '2021Q2', '2021Q2', '2021Q3'],
            'inicio': ['2020Q1', '2020Q2', '2020Q3', '2020Q4', '2021Q1', '2021Q2', '2021Q3', '2020Q2', '2020Q3', '2020Q4', '2021Q1', '2021Q2', '2021Q3', '2020Q3', '2020Q4', '2021Q1', '2021Q2', '2021Q3', '2020Q4', '2021Q1', '2021Q2', '2021Q3', '2021Q1', '2021Q2', '2021Q3', '2021Q2', '2021Q3', '2021Q3'],
            'values': [56.0, 45.0, 15.0, 7.0, 4.0, 4.0, 3.0, 418.0, 277.0, 86.0, 46.0, 33.0, 28.0, 619.0, 398.0, 167.0, 122.0, 93.0, 1163.0, 916.0, 521.0, 319.0, 976.0, 680.0, 363.0, 811.0, 559.0, 1879.0]}
    df = pd.DataFrame(data)
    
    # pivot the dataframe
    pv = df.pivot(index='dates', columns='inicio', values='values')
    
    # create figure and subplots
    fig = plt.figure(figsize=(20, 10))
    # the grid has 10 columns, so colspan=10 spans its full width
    ax1 = plt.subplot2grid((20, 10), (0, 0), colspan=10, rowspan=17)
    ax2 = plt.subplot2grid((20, 10), (19, 0), colspan=10, rowspan=1)
    
    sns.set(font_scale=0.95)
    
    # create heatmap with square=False instead of True
    sns.heatmap(pv, ax=ax1, annot=True, fmt=".0f", robust=True, linewidth=0.1, square=False, cmap="Blues")
    sns.heatmap(pv.sum().to_frame().T, ax=ax2, annot=True, fmt=".0f", robust=True, linewidth=0.1, square=False, cmap="Blues", xticklabels=False, yticklabels=False)
    
    ax1.set_yticklabels(pv.index, rotation=0)  # keep the row labels horizontal
    ax1.set(xlabel='', ylabel='')  # remove x & y labels
    ax2.set(xlabel='', ylabel='')  # remove x & y labels
    
    plt.show()
    

    [image: resulting plot]
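The same stacked layout can also be built with plt.subplots and gridspec_kw={"height_ratios": [17, 1]} instead of subplot2grid. This is an alternative sketch, not the answer's code: the data is a placeholder upper-triangular frame (values 1 to 49 with NaN below the diagonal), the names mirror the answer, and the Agg backend is only there so it runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # off-screen backend; drop this line to show the figure interactively
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# placeholder data: an upper-triangular 7x7 frame shaped like the question's pivot
quarters = ["2020Q1", "2020Q2", "2020Q3", "2020Q4", "2021Q1", "2021Q2", "2021Q3"]
values = np.arange(1.0, 50.0).reshape(7, 7)
values[np.tril_indices(7, k=-1)] = np.nan  # blank the cells below the diagonal
pv = pd.DataFrame(values, index=quarters, columns=quarters)

# two stacked rows in one column, heights in the same 17:1 ratio as the rowspans above
fig, (ax1, ax2) = plt.subplots(
    2, 1, figsize=(20, 10), gridspec_kw={"height_ratios": [17, 1]}
)

# square=False lets the cells stretch, so both heatmaps fill their full slot width
sns.heatmap(pv, ax=ax1, annot=True, fmt=".0f", square=False, cmap="Blues")
sns.heatmap(pv.sum().to_frame().T, ax=ax2, annot=True, fmt=".0f", square=False,
            cmap="Blues", xticklabels=False, yticklabels=False)

ax1.set(xlabel="", ylabel="")  # remove the axis labels
ax2.set(xlabel="", ylabel="")
fig.canvas.draw()
```

Because both heatmaps keep their default colorbar, each axes gives up the same fraction of its width, so the columns still line up; drawing a colorbar on only one of them (for example cbar=False on the totals row) would make the two widths differ again.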