python seaborn histogram

How to make a histogram for multivariate data in Python seaborn?


I have the following data and would like to plot it as a histogram like the one shown below, but I could not manage it in Python. Can anyone please help me do this in Python?

Group     Summer  Winter  Autumn  Spring
bacteria  20      30      40      20
virus     30      50      20      20
fungi     50      20      40      60

[Image: example of the desired histogram]


Solution

  • You can transform the dataframe to long form, and then call sns.histplot() with weights=... and multiple='stack'.

    from io import StringIO

    import pandas as pd
    import seaborn as sns
    from matplotlib import pyplot as plt
    
    data_str = '''Group Summer  Winter  Autumn  Spring
    bacteria    20  30  40  20
    virus   30  50  20  20
    fungi   50  20  40  60'''
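    # sep=r'\s+' splits on any run of whitespace
    # (delim_whitespace=True is deprecated in recent pandas versions)
    df = pd.read_csv(StringIO(data_str), sep=r'\s+')
    # long form: one row per (Group, Season) pair, with the number in a 'value' column
    df_long = df.melt(id_vars='Group', var_name='Season')
    hue_order = ['fungi', 'virus', 'bacteria']
    sns.set()
    # weights='value' makes each bar segment as tall as the stored number
    # instead of counting rows; multiple='stack' stacks the groups per season
    ax = sns.histplot(data=df_long, x='Season', hue='Group', hue_order=hue_order,
                      weights='value', multiple='stack',
                      palette=['orange', 'gold', 'tomato'])
    # move the legend outside the axes; sns.move_legend needs seaborn 0.11.2 or newer
    # (the legendHandles attribute was renamed in newer Matplotlib versions)
    sns.move_legend(ax, bbox_to_anchor=(1.02, 0.98), loc='upper left')
    ax.set_ylabel('Percentage')
    plt.tight_layout()
    plt.show()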
    

    [Figure: sns.histplot with weights and multiple='stack']
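
To make the "long form" step more concrete, here is a minimal sketch of what melt() produces for the example data. It builds the dataframe from a literal dict instead of read_csv purely for brevity (that is my shortcut, not part of the answer's code). The 'value' column name is the pandas default for melt().

```python
import pandas as pd

# Same numbers as in the question, typed in directly for brevity
df = pd.DataFrame({
    'Group': ['bacteria', 'virus', 'fungi'],
    'Summer': [20, 30, 50],
    'Winter': [30, 50, 20],
    'Autumn': [40, 20, 40],
    'Spring': [20, 20, 60],
})

# Each (Group, season) cell of the wide table becomes its own row;
# the numbers end up in a column named 'value' (the melt() default)
df_long = df.melt(id_vars='Group', var_name='Season')

print(df_long.shape)          # 3 groups x 4 seasons = 12 rows, 3 columns
print(list(df_long.columns))  # ['Group', 'Season', 'value']
print(df_long.head(3))        # the three Summer rows
```

This is the column layout that the x='Season', hue='Group' and weights='value' arguments of sns.histplot() refer to.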

    PS: As mentioned in the comments, multiple='fill' is another option. In that case, the bars for each x-value are stretched to fill the total height. This is especially useful when the values in the dataframe are counts rather than percentages (the example data appears to be percentages).

    The code could then look like:

    from matplotlib.ticker import PercentFormatter
    
    # ... similar data preparation as in the other example
    
    ax = sns.histplot(data=df_long, x='Season', hue='Group', hue_order=hue_order,
                      weights='value', multiple='fill',
                      palette=['orange', 'gold', 'tomato'])
    ax.set_ylabel('Percentage')
    ax.yaxis.set_major_formatter(PercentFormatter(1))
    plt.tight_layout()
    plt.show()
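
As a rough illustration of what "stretched to fill the total height" means numerically, the sketch below computes per-season shares by hand with pandas. The counts are hypothetical (made up for this illustration, not taken from the question), and the 'share' column name is my own. This is only meant to mirror the idea behind multiple='fill', not to reproduce seaborn's internals.

```python
import pandas as pd

# Hypothetical counts in long form (NOT the data from the question)
df_long = pd.DataFrame({
    'Group': ['bacteria', 'virus', 'fungi'] * 2,
    'Season': ['Summer'] * 3 + ['Winter'] * 3,
    'value': [4, 6, 10, 3, 5, 2],
})

# Total count per season, repeated on every row of that season
season_total = df_long.groupby('Season')['value'].transform('sum')

# Fraction of the season's total contributed by each group;
# within a season these fractions add up to 1
df_long['share'] = df_long['value'] / season_total

print(df_long)
```

If the fractions themselves are needed (for example to annotate the bars), computing them this way avoids reading them back from the plot.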