Search code examples
Tags: python, matplotlib, seaborn, bar-chart

Seaborn Barplot with Varying Number of Bars per Group without Blank Spaces


I'm using seaborn to plot the results of different algorithms. I want to distinguish both the different algorithms as well as their classification ("group"). The problem is that not all algorithms are in all groups, so when I use group as hue, I get a lot of blank space:

import seaborn as sns
group = ['Simple', 'Simple', 'Complex', 'Complex', 'Cool']
alg = ['Alg 1', 'Alg 2', 'Alg 3', 'Alg 4', 'Alg 2']
results = [i+1 for i in range(len(group))]
# seaborn >= 0.12 accepts only `data` positionally; x and y must be passed by keyword.
sns.barplot(x=group, y=results, hue=alg)

[Figure: resulting bar plot, with empty slots for every algorithm missing from a group]

As you can see, seaborn reserves space for every algorithm in every group, which leaves a lot of blank space. How can I avoid that? I do want to show the different groups on the x-axis and distinguish the different algorithms by color/style. Algorithms may occur in multiple groups, but not in all of them. I just want space for 2 bars in "Simple" and "Complex" and for only 1 in "Cool". Solutions in pure matplotlib are also welcome; it doesn't need to be seaborn. I'd like to keep the seaborn color palette, though.


Solution

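  • There doesn't seem to be a standard way to create this type of grouped bar plot. The following code computes a position and a color for each bar, plus the tick labels and their positions. Bars belonging to the same group should be adjacent in the input lists; each change of group name starts a new group on the x-axis.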

    import matplotlib.pyplot as plt
    import numpy as np
    import pandas as pd
    from matplotlib.patches import Patch
    
    group = ['Simple', 'Simple', 'Complex', 'Complex', 'Cool']
    alg = ['Alg 1', 'Alg 2', 'Alg 3', 'Alg 4', 'Alg 2']
    results = [i + 1 for i in range(len(group))]
    
    # One color per algorithm; wrap around instead of raising IndexError
    # when there are more algorithms than colors in tab10.
    colors = plt.cm.tab10.colors
    alg_cat = pd.Categorical(alg)
    alg_colors = [colors[c % len(colors)] for c in alg_cat.codes]
    
    dist_groups = 0.4  # extra gap inserted between successive groups
    # True wherever a bar starts a new group (never for the first bar).
    new_group = np.array([False] + [g1 != g2 for g1, g2 in zip(group[:-1], group[1:])])
    # Each bar advances by 1; a group change adds dist_groups on top.
    pos = np.cumsum(1 + new_group * dist_groups)
    
    # One tick label per contiguous run of equal group names, centered under its bars.
    # Working per run keeps labels correct if a group name reappears later.
    bounds = [0, *np.flatnonzero(new_group).tolist(), len(group)]
    labels = [group[start] for start in bounds[:-1]]
    label_pos = [pos[start:stop].mean() for start, stop in zip(bounds[:-1], bounds[1:])]
    
    plt.bar(pos, results, color=alg_colors)
    plt.xticks(label_pos, labels)
    handles = [Patch(color=colors[c % len(colors)], label=lab) for c, lab in enumerate(alg_cat.categories)]
    plt.legend(handles=handles)
    plt.show()
    

    [Figure: resulting bar plot, with 2 bars for "Simple" and "Complex" and 1 bar for "Cool"]


    from typing import List, Union
    import matplotlib.pyplot as plt
    import numpy as np
    import pandas as pd
    from matplotlib.patches import Patch
    
    
    def plot_grouped_barplot(group: List[str], alg: List[str], results: List[Union[int, float]],
                             *, dist_groups: float = 0.4) -> None:
        """
        Plots a bar plot with varying numbers of bars per group centered over the group without blank spaces.
        
        Bars are drawn in input order, one unit apart. An extra gap of ``dist_groups``
        is inserted wherever the group name changes from one bar to the next. Each
        contiguous run of equal group names gets one x-tick label centered under its
        bars. Bar color identifies the algorithm; tab10 colors are reused cyclically
        if there are more than 10 algorithms.
        
        Parameters:
        - group: List of group names (categories) for each bar. Bars of the same group
          should be adjacent; a group split into separate runs is labelled once per run.
        - alg: List of algorithm names corresponding to each bar.
        - results: List of result values (int or float) corresponding to each bar.
        - dist_groups: Extra horizontal gap inserted between successive groups (keyword-only).
        
        Raises:
        - ValueError: If group, alg and results do not all have the same length.
        """
        if not (len(group) == len(alg) == len(results)):
            raise ValueError(
                'group, alg and results must have the same length, '
                f'got {len(group)}, {len(alg)} and {len(results)}'
            )
        if not group:
            return  # nothing to draw; avoids indexing into empty lists below
        
        # Define colors using a color map; wrap around instead of raising IndexError
        # when there are more algorithms than colors.
        colors = plt.cm.tab10.colors
        alg_cat = pd.Categorical(alg)
        alg_colors = [colors[c % len(colors)] for c in alg_cat.codes]
        
        # True wherever a bar starts a new group (never for the first bar).
        new_group = np.array([False] + [g1 != g2 for g1, g2 in zip(group[:-1], group[1:])])
        # Each bar advances by 1; a group change adds dist_groups on top.
        pos = np.cumsum(1 + new_group * dist_groups)
        
        # One label per contiguous run of equal group names, centered under that run's bars.
        # Working per run (not per name) keeps labels correct when a name reappears later,
        # and avoids rescanning the whole list for every label.
        bounds = [0, *np.flatnonzero(new_group).tolist(), len(group)]
        labels = [group[start] for start in bounds[:-1]]
        label_pos = [pos[start:stop].mean() for start, stop in zip(bounds[:-1], bounds[1:])]
    
        # Create the bar plot
        fig, ax = plt.subplots(figsize=(10, 6))
        ax.bar(pos, results, color=alg_colors)
        
        # Set x-ticks and labels (the labels argument of set_xticks needs matplotlib >= 3.5)
        ax.set_xticks(label_pos, labels)
        ax.set(xlabel='Group', ylabel='Results')
        
        # Create legend: one patch per distinct algorithm, colored like its bars
        handles = [Patch(color=colors[c % len(colors)], label=lab) for c, lab in enumerate(alg_cat.categories)]
        ax.legend(handles=handles, title='Algorithm')
        
        # Show plot
        plt.show()
    
    • If the data is in lists
    # Sample data: five bars spread over three groups
    group = ['Simple', 'Simple', 'Complex', 'Complex', 'Cool']
    alg = ['Alg 1', 'Alg 2', 'Alg 3', 'Alg 4', 'Alg 2']
    results = list(range(1, len(group) + 1))  # 1, 2, ..., len(group)
    
    # Draw the grouped bar plot straight from the lists
    plot_grouped_barplot(group, alg, results)
    
    • If the data is in a dataframe
    # The same sample data, wrapped in a DataFrame
    df = pd.DataFrame({'Group': group, 'Algorithm': alg, 'Results': results})
    
    # Hand each column to the plotting function as a plain Python list
    plot_grouped_barplot(df['Group'].tolist(), df['Algorithm'].tolist(), df['Results'].tolist())