matplotlib barplot with groups using a dictionary of lists of lists


I have some measurements of the hardness of steels after quenching in different coolants: water, oil, and air (just leaving it to cool down). The measurements are organized as follows: a dict called coolant_data contains three string:list pairs. Each string is a coolant, and each list holds the measurements taken with it. Each of these lists contains three inner lists, one per sample, holding all the measurements from that sample.

I have calculated the means and standard deviations of the measurements from each sample, and placed them in coolant_samples and coolant_sample_stds, respectively. I want to plot all the data from coolant_samples, with coolant_sample_stds as the error bars, in a bar chart using plt. So far so easy.

The part I'm having trouble with is this: the columns from each list should be adjacent, in the same group. That is, the groups should be organized by coolant, with each group containing three columns for the means of the measurements of the three samples.

So far I have the following code:

import numpy as np
import matplotlib.pyplot as plt

# Hardness data [HRC]
coolant_data = {
    "Water": [[27.0, 29.0, 30.0, 28.5, 27.5], [21.5, 29.0, 28.5, 21.0, 30.0], [25.0, 22.0, 28.0, 31.0, 26.0]],
    "Oil": [[11.5, 10.0, 11.5, 9.5, 4.5], [11.0, 12.0, 12.0, 11.0, 12.0], [9.5, 10.0, 11.0, 10.5, 11.0]],
    "Air": [[2.5, 3.0, 3.0, 3.5, 1.0], [2.0, 1.5, 3.0, 4.0, 3.5], [2.0, 1.5, 3.0, 2.0, 1.5]]}

# Calculate means and standard deviations
coolant_samples = {coolant: [np.mean(sample) for sample in measurements] for coolant, measurements in coolant_data.items()}
coolant_sample_stds = {coolant: [np.std(sample) for sample in measurements] for coolant, measurements in coolant_data.items()}

# Plot the hardness data as a bar chart with error bars for each sample and the mean
plt.figure()
plt.title("Hardness of Samples After Quenching in Different Coolants")
plt.ylabel("Hardness [HRC]")
labels = coolant_samples.keys()

# Create the bars with grouped x-axis values
x = range(len(labels))
width = 0.25  # Width of each bar
plt.bar_label(plt.bar([i - width for i in x], coolant_samples['Water'], width, label='Sample 1', yerr=coolant_sample_stds['Water']), padding=3)
plt.bar_label(plt.bar([i for i in x], coolant_samples['Oil'], width, label='Sample 2', yerr=coolant_sample_stds['Oil']), padding=3)
plt.bar_label(plt.bar([i + width for i in x], coolant_samples['Air'], width, label='Sample 3', yerr=coolant_sample_stds['Air']), padding=3)

plt.xticks(x, labels)
plt.ylim(bottom=0, top=35)
plt.legend(loc='upper left', ncols=3)
plt.show()

And I'm getting this graph:

A bar chart with the data grouped incorrectly

As you can see, the columns are grouped incorrectly. I would really appreciate help on this.


Solution

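  • One easy option would be to use pandas. Although pandas is not ideal performance-wise for handling lists, this is easily done with DataFrame.map (available from pandas 2.1; older versions use applymap instead):

    import numpy as np
    import pandas as pd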
    
    df = pd.DataFrame(coolant_data).T
    
    avg = df.map(np.mean)
    
    ax = avg.plot.bar()
    for c in ax.containers:
        ax.bar_label(c)
    

    With error bars:

    import pandas as pd
    from matplotlib.container import BarContainer
    
    df = pd.DataFrame(coolant_data).T
    
    avg = df.map(np.mean)
    std = df.map(np.std)
    
    ax = avg.plot.bar(yerr=std)
    for c in ax.containers:
        if isinstance(c, BarContainer):
            ax.bar_label(c)
    

    Using seaborn:

    import pandas as pd
    import seaborn as sns
    
    df = (pd.DataFrame({(k, f'sample {s}'): l for k, lst in coolant_data.items()
                        for s, l in enumerate(lst, start=1)})
            .rename_axis(columns=['coolant', 'sample'])
            .melt(value_name='hardness')
         )
    
    ax = sns.barplot(data=df, x='coolant', hue='sample', y='hardness', errorbar='sd')
    for c in ax.containers:
        ax.bar_label(c, label_type='center')
    

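For completeness, here is a sketch of how the same grouping could be done in plain matplotlib, without pandas or seaborn. In the question's code, each plt.bar call passes one coolant's three sample means as the heights for the three x positions, so every coolant gets spread across all three groups. The sketch below instead loops over the sample index and takes that sample's mean from every coolant. The names means, stds, n_samples and offset are my own, and using the object-oriented fig/ax interface is an assumption, not something the question requires.

```python
import matplotlib.pyplot as plt
import numpy as np

# Hardness data [HRC], copied from the question
coolant_data = {
    "Water": [[27.0, 29.0, 30.0, 28.5, 27.5], [21.5, 29.0, 28.5, 21.0, 30.0], [25.0, 22.0, 28.0, 31.0, 26.0]],
    "Oil": [[11.5, 10.0, 11.5, 9.5, 4.5], [11.0, 12.0, 12.0, 11.0, 12.0], [9.5, 10.0, 11.0, 10.5, 11.0]],
    "Air": [[2.5, 3.0, 3.0, 3.5, 1.0], [2.0, 1.5, 3.0, 4.0, 3.5], [2.0, 1.5, 3.0, 2.0, 1.5]],
}

labels = list(coolant_data)
# 2D arrays indexed as [coolant, sample]
means = np.array([[np.mean(s) for s in coolant_data[c]] for c in labels])
stds = np.array([[np.std(s) for s in coolant_data[c]] for c in labels])

n_samples = means.shape[1]
x = np.arange(len(labels))  # one group position per coolant
width = 0.25                # width of each bar

fig, ax = plt.subplots()
for s in range(n_samples):
    # Shift each sample's bars so the group stays centred on its tick
    offset = (s - (n_samples - 1) / 2) * width
    # Column s holds sample s of every coolant: one bar per group
    bars = ax.bar(x + offset, means[:, s], width,
                  yerr=stds[:, s], label=f"Sample {s + 1}")
    ax.bar_label(bars, padding=3)

ax.set_title("Hardness of Samples After Quenching in Different Coolants")
ax.set_ylabel("Hardness [HRC]")
ax.set_xticks(x)
ax.set_xticklabels(labels)
ax.set_ylim(bottom=0, top=35)
ax.legend(loc="upper left", ncol=3)
```

Call plt.show() (or fig.savefig with a file name of your choice) afterwards to see the figure. Each legend entry then corresponds to a sample, and each tick to a coolant.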