Tags: python, pandas, matplotlib, seaborn

multi index plotting


I have some data where I've manipulated the dataframe using the following code:

import pandas as pd
import numpy as np

data = pd.DataFrame([[0,0,0,3,6,5,6,1],[1,1,1,3,4,5,2,0],[2,1,0,3,6,5,6,1],[3,0,0,2,9,4,2,1],[4,0,1,3,4,8,1,1],[5,1,1,3,3,5,9,1],[6,1,0,3,3,5,6,1],[7,0,1,3,4,8,9,1]], columns=["id", "sex", "split", "group0Low", "group0High", "group1Low", "group1High", "trim"])
data

#remove all where trim == 0
trimmed = data[(data.trim == 1)]
trimmed

#create df with columns to be split
columns = ['group0Low', 'group0High', 'group1Low', 'group1High']
to_split = trimmed[columns]
to_split

level_group = np.where(to_split.columns.str.contains('0'), 0, 1)
# output: array([0, 0, 1, 1])
level_low_high = np.where(to_split.columns.str.contains('Low'), 'low', 'high')
# output: array(['low', 'high', 'low', 'high'], dtype='<U4')

multi_level_columns = pd.MultiIndex.from_arrays([level_group, level_low_high], names=['group', 'val'])
to_split.columns = multi_level_columns
to_split.stack(level='group')

sex = trimmed['sex']
split = trimmed['split']
horizontalStack = pd.concat([sex, split, to_split], axis=1)
horizontalStack

finalData = horizontalStack.groupby(['split', 'sex', 'group'])
finalData.mean()

My question is: how do I plot the mean data using ggplot or seaborn so that each "split" level gets its own graph, with the bars separated by group? (The example image of the desired plot is not included here.)

At the bottom of the code you can see I've tried to split up the group factor so I can separate the bars, but that resulted in an error (KeyError: 'group'), which I think is related to the way I used multi-indexing.


Solution

  • I would use seaborn's catplot (formerly called factorplot).

    Say you have data like this:

    import numpy as np
    import pandas
    
    import seaborn
    seaborn.set(style='ticks') 
    np.random.seed(0)
    
    groups = ('Group 1', 'Group 2')
    sexes = ('Male', 'Female')
    means = ('Low', 'High')
    index = pandas.MultiIndex.from_product(
        [groups, sexes, means], 
        names=['Group', 'Sex', 'Mean']
    )
    
    values = np.random.randint(low=20, high=100, size=len(index))
    data = pandas.DataFrame(data={'val': values}, index=index).reset_index()
    print(data)
    
         Group     Sex  Mean  val
    0  Group 1    Male   Low   64
    1  Group 1    Male  High   67
    2  Group 1  Female   Low   84
    3  Group 1  Female  High   87
    4  Group 2    Male   Low   87
    5  Group 2    Male  High   29
    6  Group 2  Female   Low   41
    7  Group 2  Female  High   56
    

    You can then create the plot with one command, plus an extra line to remove the x-labels, which are redundant for your data:

    # Note: catplot used to be called factorplot
    fg = seaborn.catplot(x='Group', y='val', hue='Mean', 
                            col='Sex', data=data, kind='bar')
    fg.set_xlabels('')
    

    Which gives me a grid of bar charts: one panel per Sex, with Group on the x-axis and Low/High as the bar colours (image not included).
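
Applied to the data from the question, the same idea might look like the sketch below. It assumes the KeyError arises because 'group' only ever existed as a column level (the result of to_split.stack(level='group') was never assigned), so by the time groupby runs there is no 'group' column to find. Instead of building a MultiIndex, the sketch melts the frame into long form, which is the shape catplot expects. The names to_long, plot_means, tidy and long_data are illustrative and not from the original posts, and using row='split' to get one row of panels per split level is my reading of the request, not something the question confirms.

```python
import pandas as pd

# Same data as in the question.
data = pd.DataFrame(
    [[0, 0, 0, 3, 6, 5, 6, 1], [1, 1, 1, 3, 4, 5, 2, 0],
     [2, 1, 0, 3, 6, 5, 6, 1], [3, 0, 0, 2, 9, 4, 2, 1],
     [4, 0, 1, 3, 4, 8, 1, 1], [5, 1, 1, 3, 3, 5, 9, 1],
     [6, 1, 0, 3, 3, 5, 6, 1], [7, 0, 1, 3, 4, 8, 9, 1]],
    columns=["id", "sex", "split", "group0Low", "group0High",
             "group1Low", "group1High", "trim"],
)


def to_long(df):
    """Hypothetical helper: one row per (id, group, low/high) observation."""
    trimmed = df[df.trim == 1]  # drop rows where trim == 0
    value_cols = ["group0Low", "group0High", "group1Low", "group1High"]
    tidy = trimmed.melt(
        id_vars=["id", "sex", "split"],
        value_vars=value_cols,
        var_name="column",
        value_name="value",
    )
    # Split names such as 'group0Low' into group=0 and val='low'.
    parts = tidy["column"].str.extract(r"group(?P<group>\d)(?P<val>Low|High)")
    tidy["group"] = parts["group"].astype(int)
    tidy["val"] = parts["val"].str.lower()
    return tidy.drop(columns="column")


def plot_means(long_data):
    """Hypothetical helper: bar plot of the means, one row of panels per split."""
    import seaborn  # imported here so the reshaping works without seaborn

    # kind='bar' plots the mean of 'value' per category by default,
    # so the un-aggregated long data can be passed directly.
    fg = seaborn.catplot(
        data=long_data, x="group", y="value", hue="val",
        col="sex", row="split", kind="bar",
    )
    fg.set_axis_labels("group", "mean value")
    return fg


long_data = to_long(data)

# 'group' is now an ordinary column, so this groupby no longer raises KeyError.
means = long_data.groupby(["split", "sex", "group", "val"])["value"].mean()

if __name__ == "__main__":
    print(means)
    plot_means(long_data)
```

Because catplot computes the means itself, the groupby is only there to inspect the numbers; the bars should match the values it prints.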