python pandas seaborn bar-chart heatmap

How to align yticklabels when combining a barplot with heatmap


I have a problem similar to this question: I am trying to combine three plots in Seaborn, but the labels on my y-axis are not aligned with the bars.

My code (now a working copy-paste example):

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm

### Generate example data
np.random.seed(123)
year = [2018, 2019, 2020, 2021]
task = [x + 2 for x in range(18)]
student = [x for x in range(200)]
amount = [x + 10 for x in range(90)]
violation = [letter for letter in "thisisjustsampletextforlabels"] # one letter labels

df_example = pd.DataFrame({

    # some ways to create random data
    'year':np.random.choice(year,500),
    'task':np.random.choice(task,500),
    'violation':np.random.choice(violation, 500),
    'amount':np.random.choice(amount, 500),
    'student':np.random.choice(student, 500)
})

### My code
temp = df_example.groupby(["violation"])["amount"].sum().sort_values(ascending = False).reset_index()
total_violations = temp["amount"].sum()
sns.set(font_scale = 1.2)


f, axs = plt.subplots(1,3,
                      figsize=(5,5),
                      sharey="row",
                      gridspec_kw=dict(width_ratios=[3,1.5,5]))

# Plot frequency
df1 = df_example.groupby(["year","violation"])["amount"].sum().sort_values(ascending = False).reset_index()
frequency = sns.barplot(data = df1, y = "violation", x = "amount", log = True, ax=axs[0])


# Plot percent
df2 = df_example.groupby(["violation"])["amount"].sum().sort_values(ascending = False).reset_index()
total_violations = df2["amount"].sum()
percent = sns.barplot(x='amount', y='violation', estimator=lambda x: sum(x) / total_violations * 100, data=df2, ax=axs[1])

# Pivot table and plot heatmap 
df_heatmap = df_example.groupby(["violation", "task"])["amount"].sum().sort_values(ascending = False).reset_index()
# keyword arguments are required by DataFrame.pivot in pandas >= 2.0
df_heatmap_pivot = df_heatmap.pivot(index="violation", columns="task", values="amount")
df_heatmap_pivot = df_heatmap_pivot.reindex(index=df_heatmap["violation"].unique())
heatmap = sns.heatmap(df_heatmap_pivot, fmt = "d", cmap="Greys", norm=LogNorm(), ax=axs[2])
plt.subplots_adjust(top=1)


axs[2].set_facecolor('xkcd:white')
axs[2].set(ylabel="",xlabel="Task")

axs[0].set_xlabel('Total amount of violations per year')
axs[1].set_xlabel('Percent (%)')

axs[1].set_ylabel('')
axs[0].set_ylabel('Violation')

The result can be seen here:

[Figure: Barplot and Heatmap]

The y-labels are aligned according to my last plot, the heatmap. However, the bars in the bar plots are clipped at the top and are not aligned with the labels. I just have to nudge the bars in the barplot, but how? I've been looking through the documentation, but I feel quite clueless as of now.


Solution

    • None of the y-axis ticklabels are aligned because multiple dataframes are used for plotting. It is better to create a single dataframe, violations, with the aggregated data to be plotted. Start with the sum of amounts by violation, and then add a new percent column. This ensures the two bar plots have the same y-axis.
    • Instead of using .groupby and then .pivot to create df_heatmap_pivot, use .pivot_table, and then reindex using violations.violation.
    • Tested in python 3.10, pandas 1.4.3, matplotlib 3.5.1, seaborn 0.11.2
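
As a toy sketch of why the reindexing step matters, assume a hypothetical three-violation dataset (the names records, totals, pivot, and aligned are illustrative and not from the question). .pivot_table returns its index sorted alphabetically, so without .reindex the heatmap rows would not follow the descending-total order used by the bar plots.

```python
import pandas as pd

# Hypothetical toy records: three violations spread over two tasks
records = pd.DataFrame({
    "violation": ["s", "a", "s", "t", "a", "t", "t"],
    "task": [2, 2, 3, 3, 3, 2, 2],
    "amount": [10, 1, 20, 5, 2, 7, 8],
})

# Totals per violation, largest first (this is the order the bar plots use)
totals = (records.groupby("violation")["amount"].sum()
          .sort_values(ascending=False).reset_index())

# pivot_table sorts the index alphabetically, regardless of the totals
pivot = records.pivot_table(index="violation", columns="task",
                            values="amount", aggfunc="sum")

# Reindex so the heatmap rows follow the same order as the bars
aligned = pivot.reindex(index=totals["violation"])

print(list(pivot.index))    # alphabetical order
print(list(aligned.index))  # order of descending totals
```

The values in each row are unchanged by the reindex; only the row order differs.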

    DataFrames and Imports

    import pandas as pd
    import numpy as np
    import seaborn as sns  # needed for sns.set, sns.barplot and sns.heatmap below
    import matplotlib.pyplot as plt
    from matplotlib.colors import LogNorm
    
    # Generate example data
    year = [2018, 2019, 2020, 2021]
    task = [x + 2 for x in range(18)]
    student = [x for x in range(200)]
    amount = [x + 10 for x in range(90)]
    violation = list("thisisjustsampletextforlabels")  # one letter labels
    
    np.random.seed(123)
    df_example = pd.DataFrame({name: np.random.choice(group, 500) for name, group in
                               zip(['year', 'task', 'violation', 'amount', 'student'],
                                   [year, task, violation, amount, student])})
    
    # organize all of the data
    # violations frequency
    violations = df_example.groupby(["violation"])["amount"].sum().sort_values(ascending=False).reset_index()
    total_violations = violations["amount"].sum()
    
    # add percent
    violations['percent'] = violations.amount.div(total_violations).mul(100).round(2)
    
    # Use .pivot_table to create the pivot table
    df_heatmap_pivot = df_example.pivot_table(index='violation', columns='task', values='amount', aggfunc='sum')
    # Set the index to match the plot order of the 'violation' column 
    df_heatmap_pivot = df_heatmap_pivot.reindex(index=violations.violation)
    

    Plotting

    • Using sharey='row' causes the alignment problem. Use sharey=False, and remove the yticklabels from axs[1] and axs[2] with axs[1].set_yticks([]) and axs[2].set_yticks([]).
      • This is because ylim for the heatmap is not the same as for the barplot, so with a shared y-axis the bars end up shifted relative to the heatmap rows.
      • .get_ylim() for axs[0] and axs[1] is (15.5, -0.5), while for axs[2] it is (16.0, 0.0).
    • See How to add value labels on a bar chart for additional details and examples using .bar_label.
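
The y-limit mismatch described above can be reproduced on a small scale. The following is a minimal sketch, assuming hypothetical four-category toy data (bars, grid, and the helper ylims are illustrative names, not part of the question). It draws a barplot next to a heatmap twice, once with independent y-axes and once with a shared one, and reports the limits of both axes.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# Hypothetical toy data: four categories, not the question's data
bars = pd.DataFrame({"label": list("abcd"), "value": [4, 3, 2, 1]})
grid = pd.DataFrame([[1, 2], [3, 4], [5, 6], [7, 8]],
                    index=list("abcd"), columns=["x", "y"])


def ylims(sharey):
    """Draw a barplot next to a heatmap and return both axes' y-limits."""
    fig, (ax_bar, ax_heat) = plt.subplots(1, 2, sharey=sharey)
    sns.barplot(data=bars, x="value", y="label", ax=ax_bar)
    sns.heatmap(grid, cbar=False, ax=ax_heat)
    limits = (tuple(map(float, ax_bar.get_ylim())),
              tuple(map(float, ax_heat.get_ylim())))
    plt.close(fig)
    return limits


independent = ylims(sharey=False)  # each axes keeps its own limits
shared = ylims(sharey=True)        # the heatmap, plotted last, sets both

print("sharey=False:", independent)
print("sharey=True: ", shared)
```

With independent axes the barplot should keep its half-unit padding while the heatmap starts at 0, the same pattern as the (15.5, -0.5) versus (16.0, 0.0) values reported above. With a shared axis the heatmap's limits are applied to the barplot as well, which is what shifts and clips the bars.
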
    # set seaborn plot format
    sns.set(font_scale=1.2)
    
    # create the figure and set sharey=False
    f, axs = plt.subplots(1, 3, figsize=(12, 12), sharey=False, gridspec_kw=dict(width_ratios=[3,1.5,5]))
    
    # Plot frequency
    sns.barplot(data=violations, x="amount", y="violation", log=True, ax=axs[0])
    
    # Plot percent
    sns.barplot(data=violations, x='percent', y='violation', ax=axs[1])
    
    # add the bar labels
    axs[1].bar_label(axs[1].containers[0], fmt='%.2f%%', label_type='edge', padding=3)
    # add extra space for the annotation
    axs[1].margins(x=1.3)
    
    # plot the heatmap
    heatmap = sns.heatmap(df_heatmap_pivot, fmt = "d", cmap="Greys", norm=LogNorm(), ax=axs[2])
    
    # additional formatting
    axs[2].set_facecolor('xkcd:white')
    axs[2].set(ylabel="", xlabel="Task")
    
    axs[0].set_xlabel('Total amount of violations per year')
    axs[1].set_xlabel('Percent (%)')
    
    axs[1].set_ylabel('')
    axs[0].set_ylabel('Violation')
    
    # remove yticks / labels
    axs[1].set_yticks([])  
    _ = axs[2].set_yticks([])
    


    • Comment out the last two lines to verify the yticklabels are aligned for each axs.
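
The same check can also be done numerically instead of by eye. This is a rough sketch under the assumption of hypothetical toy data (bars, grid, and the helper tick_pixels are illustrative names): it converts each y tick to display coordinates and compares the barplot against the heatmap when sharey=False.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Hypothetical toy data: four categories in the same order for both plots
bars = pd.DataFrame({"label": list("abcd"), "value": [4, 3, 2, 1]})
grid = pd.DataFrame([[1, 2], [3, 4], [5, 6], [7, 8]],
                    index=list("abcd"), columns=["x", "y"])

fig, (ax_bar, ax_heat) = plt.subplots(1, 2, sharey=False)
sns.barplot(data=bars, x="value", y="label", ax=ax_bar)
sns.heatmap(grid, cbar=False, yticklabels=True, ax=ax_heat)
fig.canvas.draw()  # make sure tick labels and layout are finalized


def tick_pixels(ax):
    """Vertical display (pixel) coordinate of each y tick on the given axes."""
    points = [(0, y) for y in ax.get_yticks()]
    return ax.transData.transform(points)[:, 1]


bar_labels = [t.get_text() for t in ax_bar.get_yticklabels()]
heat_labels = [t.get_text() for t in ax_heat.get_yticklabels()]
bar_px = tick_pixels(ax_bar)
heat_px = tick_pixels(ax_heat)

print("same labels:   ", bar_labels == heat_labels)
print("same positions:", np.allclose(bar_px, heat_px))
```

The tick positions differ in data coordinates (integers for the bars, half-integers for the heatmap rows), but because each axes keeps its own limits they map to the same height in the figure.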


    DataFrame Views

    df_example.head()

       year  task violation  amount  student
    0  2020     2         i      84       59
    1  2019     2         u      12      182
    2  2020     5         s      20        9
    3  2020    11         u      56      163
    4  2018    17         t      59      125
    

    violations

       violation  amount  percent
    0          s    4869    17.86
    1          l    3103    11.38
    2          t    3044    11.17
    3          e    2634     9.66
    4          a    2177     7.99
    5          i    2099     7.70
    6          h    1275     4.68
    7          f    1232     4.52
    8          b    1191     4.37
    9          m    1155     4.24
    10         o    1075     3.94
    11         p     763     2.80
    12         r     762     2.80
    13         j     707     2.59
    14         u     595     2.18
    15         x     578     2.12
    

    df_heatmap_pivot

    task          2      3      4      5      6      7      8      9      10     11     12     13     14     15     16     17     18     19
    violation                                                                                                                              
    s           62.0   36.0  263.0  273.0  191.0  250.0  556.0  239.0  230.0  188.0  185.0  516.0  249.0  331.0  212.0  219.0  458.0  411.0
    l           83.0  245.0  264.0  451.0  155.0  314.0   98.0  125.0  310.0  117.0   21.0   99.0   98.0   50.0   40.0  268.0  192.0  173.0
    t          212.0  255.0   45.0  141.0   74.0  135.0   52.0  202.0  107.0  128.0  158.0    NaN  261.0  137.0  339.0  207.0  362.0  229.0
    e          215.0  315.0    NaN  116.0  213.0  165.0  130.0  194.0   56.0  355.0   75.0    NaN  118.0  189.0  160.0  177.0   79.0   77.0
    a          135.0    NaN  165.0  156.0  204.0  115.0   77.0   65.0   80.0  143.0   83.0  146.0   21.0   29.0  285.0   72.0  116.0  285.0
    i          209.0    NaN   20.0  187.0   83.0  136.0   24.0  132.0  257.0   56.0  201.0   52.0  136.0  226.0  104.0  145.0   91.0   40.0
    h           27.0    NaN  255.0    NaN   99.0    NaN   71.0   53.0  100.0   89.0    NaN  106.0    NaN  170.0   86.0   79.0  140.0    NaN
    f           75.0   23.0   99.0    NaN   26.0  103.0    NaN  185.0   99.0  145.0    NaN   63.0   64.0   29.0  114.0  141.0   38.0   28.0
    b           44.0   70.0   56.0   12.0   55.0   14.0  158.0  130.0    NaN   11.0   21.0    NaN   52.0  137.0  162.0    NaN  231.0   38.0
    m           86.0    NaN    NaN  147.0   74.0  131.0   49.0  180.0   94.0   16.0    NaN   88.0    NaN    NaN    NaN   51.0  161.0   78.0
    o          109.0    NaN   51.0    NaN    NaN    NaN   20.0  139.0  149.0    NaN  101.0   60.0    NaN  143.0   39.0   73.0   10.0  181.0
    p           16.0    NaN  197.0   50.0   87.0    NaN   88.0    NaN   11.0  162.0    NaN   14.0    NaN   78.0   45.0    NaN    NaN   15.0
    r            NaN   85.0   73.0   40.0    NaN    NaN   68.0   77.0    NaN   26.0  122.0  105.0    NaN   98.0    NaN    NaN    NaN   68.0
    j            NaN   70.0    NaN    NaN   73.0   76.0    NaN  150.0    NaN    NaN    NaN   81.0    NaN   97.0   97.0   63.0    NaN    NaN
    u          174.0   45.0    NaN    NaN   32.0    NaN    NaN   86.0   30.0   56.0   13.0    NaN   24.0    NaN    NaN   69.0   54.0   12.0
    x           69.0   29.0    NaN  106.0    NaN   43.0    NaN    NaN    NaN   97.0   56.0   29.0  149.0    NaN    NaN    NaN    NaN    NaN