
How to filter rows within a pandas multi index to get a 1:1 relationship between the multi index levels


I am trying to filter a DataFrame with a two-level MultiIndex down to one row per value of the first index level, so that each first-level value maps to exactly one second-level value. The row to keep in each group is the one with the largest sum across its columns. The code below creates the starting output (start_df) and the end output (end_df) that I want to transform it into.

import numpy as np
import pandas as pd


# Starting df
tuples = [
    (1, 1),
    (2, 9), (2, 4), (2, 3),
    (3, 2), (3, 11)
]

start_df = pd.DataFrame(
    {
       'col1':  [1,1,2,1,2,1],
       'col2': [1,1,1,1,2,1]
    },
    index=pd.MultiIndex.from_tuples(tuples)
    )


# Ending df
tuples = [
    (1, 1),
    (2, 4), 
    (3, 2), 
]

end_df = pd.DataFrame(
    {
       'col1':  [1,2,2],
       'col2': [1,1,2]
    },
    index=pd.MultiIndex.from_tuples(tuples)
    )

Solution

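  • Try grouping on the first index level, finding the label of the row with the largest row sum in each group with idxmax, and selecting those rows with .loc:

    # For each level-0 group, idxmax returns the full (level 0, level 1) label of the max-sum row
    x = start_df.loc[start_df.groupby(level=0).apply(lambda g: g.sum(axis=1).idxmax())]
    print(x)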
    

    Prints:

         col1  col2
    1 1     1     1
    2 4     2     1
    3 2     2     2
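
The same idea can also be written without apply, by computing the row sums once and then calling idxmax per group. This is only a sketch of an alternative, not part of the original answer; the names row_sums, best_labels and result are made up for illustration. If two rows in a group tie for the largest sum, idxmax keeps the first one.

```python
import pandas as pd

# Rebuild the starting frame from the question
tuples = [
    (1, 1),
    (2, 9), (2, 4), (2, 3),
    (3, 2), (3, 11),
]
start_df = pd.DataFrame(
    {
        "col1": [1, 1, 2, 1, 2, 1],
        "col2": [1, 1, 1, 1, 2, 1],
    },
    index=pd.MultiIndex.from_tuples(tuples),
)

# Sum across columns once; the result keeps the original MultiIndex
row_sums = start_df.sum(axis=1)

# For each first-level value, get the full index label of the largest sum
best_labels = row_sums.groupby(level=0).idxmax()

# Select those rows; a plain list of tuples is passed to .loc
result = start_df.loc[best_labels.tolist()]
print(result)
```

For the sample data this should select the same three rows as end_df in the question.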