Tags: python, pandas, group-by, aggregation

Pandas: Tidy up groupby aggregation


I am struggling to turn the result of a groupby aggregation back into a "normal" flat DataFrame. My table has the following columns:

RnnSize     EmbSize     RnnLayer    Epochs  Alpha   Eval    Run     Result

I calculated the mean and standard deviation of the Result column over multiple runs with this command:

df.groupby(["RnnSize", "EmbSize", "RnnLayer", "Epochs", "Alpha", "Eval"]).agg({'Result': ['mean', 'std']})

The output is a DataFrame like this:

                                                             Result
                                                             mean   std
RnnSize     EmbSize     RnnLayer    Epochs  Alpha   Eval        

The header looks as if it has three levels.

df.columns outputs the following MultiIndex:

MultiIndex([(   'index',    ''),
            ( 'RnnSize',    ''),
            ( 'EmbSize',    ''),
            ('RnnLayer',    ''),
            (  'Epochs',    ''),
            (   'Alpha',    ''),
            (    'Eval',    ''),
            (  'Result', 'mean'),
            (  'Result', 'std')],
           )

How do I flatten this again, removing "Result" and putting mean and std on the same level as the other columns? There are many commands such as reset_index, droplevel and so on, but I have not yet worked out how to fix this, and it confuses me.

Edit: For reproducibility, here is my entire code:

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

dfRuns = pd.read_csv("Results.csv", encoding="utf-8")
dfRuns

dfAv = dfRuns.copy()
dfAv = dfAv.groupby(["RnnSize", "EmbSize", "RnnLayer", "Epochs", "Alpha", "Eval"]).agg({'Result': ['mean', 'std']})

And the (shortened) csv file Results.csv:

RnnSize,EmbSize,RnnLayer,Epochs,Alpha,Eval,Run,Result
128,200,2,150,0.1,Precision,1,0.5940
128,200,2,150,0.1,Recall,1,0.5038
128,200,2,150,0.1,F1,1,0.5144
128,200,2,150,0.1,Precision,2,0.5851
128,200,2,150,0.1,Recall,2,0.4995
128,200,2,150,0.1,F1,2,0.5082

Solution

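  • Use reset_index() to turn the group keys back into ordinary columns, then flatten the column MultiIndex by joining each (level 0, level 1) tuple into a single name. The aggregated columns end up named "Result mean" and "Result std":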

    df = df.reset_index()
    df.columns = [' '.join(col).rstrip() for col in df.columns.to_numpy()]
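
The following is a minimal, self-contained sketch of that approach, using the shortened Results.csv from the question inlined as a string. The variable names (csv_text, keys, df_joined, df_named) are illustrative and not from the original post. Because the question also asks to drop the "Result" prefix entirely, the sketch adds one possible alternative, pandas named aggregation, which should produce flat "mean" and "std" columns directly. This assumes pandas 0.25 or later.

```python
import io

import pandas as pd

# Shortened Results.csv from the question, inlined so the sketch runs on its own.
csv_text = """RnnSize,EmbSize,RnnLayer,Epochs,Alpha,Eval,Run,Result
128,200,2,150,0.1,Precision,1,0.5940
128,200,2,150,0.1,Recall,1,0.5038
128,200,2,150,0.1,F1,1,0.5144
128,200,2,150,0.1,Precision,2,0.5851
128,200,2,150,0.1,Recall,2,0.4995
128,200,2,150,0.1,F1,2,0.5082
"""

keys = ["RnnSize", "EmbSize", "RnnLayer", "Epochs", "Alpha", "Eval"]
df_runs = pd.read_csv(io.StringIO(csv_text))

# Approach from the answer: aggregate, move the group keys back into columns,
# then join each column tuple, e.g. ('Result', 'mean') -> 'Result mean'.
df_joined = df_runs.groupby(keys).agg({"Result": ["mean", "std"]}).reset_index()
df_joined.columns = [" ".join(col).rstrip() for col in df_joined.columns.to_numpy()]

# Alternative (not from the answer): named aggregation gives single-level
# column names, so no "Result" prefix has to be stripped afterwards.
df_named = (
    df_runs.groupby(keys)
    .agg(mean=("Result", "mean"), std=("Result", "std"))
    .reset_index()
)

print(df_joined.columns.tolist())
print(df_named.columns.tolist())
```

With named aggregation the output column names are chosen up front, which avoids the MultiIndex altogether. The join-based version is more general when many columns and functions are aggregated at once.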