Tags: python, pandas, multi-index

Python: Keep the last n columns of a pandas MultiIndex with respect to column index level 1


I have a pd.DataFrame with two levels of columns. For every label in level 0, I need to keep the last n columns of level 1 and delete all columns before them. The number of level-1 columns is not necessarily the same under every level-0 label.

import numpy as np
import pandas as pd

df = pd.DataFrame(np.random.randint(low=1, high=5, size=(4, 12)))
df.columns = pd.MultiIndex.from_product([[1, 2, 3], ['A', 'B', 'C', 'D']])
# Drop some columns so the level-0 groups have different sizes
df.drop((2, 'A'), axis=1, inplace=True)
df.drop((3, 'A'), axis=1, inplace=True)
df.drop((3, 'C'), axis=1, inplace=True)

   1           2        3   
   A  B  C  D  B  C  D  B  D
0  3  1  4  3  4  2  4  4  4
1  4  1  4  1  1  2  4  1  1
2  3  4  3  2  3  4  3  3  1
3  2  4  4  1  4  1  1  2  3
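Because np.random.randint produces new values on every run, the table above cannot be reproduced exactly from that code. As a sketch, the frame can instead be rebuilt from the printed values, which makes it possible to check a solution against the expected result:

```python
import pandas as pd

# Rebuild the exact frame printed above (values copied from the table)
# instead of using np.random, so results can be compared with the expected output.
columns = pd.MultiIndex.from_tuples([
    (1, 'A'), (1, 'B'), (1, 'C'), (1, 'D'),
    (2, 'B'), (2, 'C'), (2, 'D'),
    (3, 'B'), (3, 'D'),
])
df = pd.DataFrame(
    [[3, 1, 4, 3, 4, 2, 4, 4, 4],
     [4, 1, 4, 1, 1, 2, 4, 1, 1],
     [3, 4, 3, 2, 3, 4, 3, 3, 1],
     [2, 4, 4, 1, 4, 1, 1, 2, 3]],
    columns=columns,
)
print(df)
```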

Expected result (n = 2):

   1     2     3   
   C  D  C  D  B  D
0  4  3  2  4  4  4
1  4  1  2  4  1  1
2  3  2  4  3  3  1
3  4  1  1  1  2  3

Solution

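  • Use GroupBy.cumcount with ascending=False to number the columns from the back within each level-0 group, then keep the last 2 columns of every group with a boolean mask in DataFrame.loc. cumcount does not work on the column levels directly, so the columns are first converted to a DataFrame with MultiIndex.to_frame. (The printed values differ from the question because the sample data is random.)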

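    # Counter from the back within each level-0 group (0 = last column);
    # keep the columns whose counter is < 2, i.e. the last 2 per group.
    df = df.loc[:, df.columns.to_frame().groupby(level=0).cumcount(ascending=False) < 2]
    print(df)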
       1     2     3   
       C  D  C  D  B  D
    0  4  4  2  1  3  3
    1  1  1  2  1  4  2
    2  1  1  2  3  4  2
    3  4  3  1  3  4  4
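The same approach can be wrapped in a small helper with n as a parameter. The name keep_last_n_per_level0 is hypothetical, and the data is the question's printed frame rebuilt by hand, so the result can be compared with the expected output. Converting the mask with to_numpy() makes the column selection purely positional; this is a defensive choice on my part, not something the answer requires.

```python
import pandas as pd


def keep_last_n_per_level0(df: pd.DataFrame, n: int) -> pd.DataFrame:
    """Keep the last n level-1 columns under each level-0 label (hypothetical helper)."""
    # to_frame() turns the column labels into rows, so groupby/cumcount can run on them;
    # ascending=False numbers every level-0 group from the back (0 = last column).
    counter = df.columns.to_frame().groupby(level=0).cumcount(ascending=False)
    # True for the last n columns of each group; groups with fewer columns keep all.
    return df.loc[:, (counter < n).to_numpy()]


# The question's printed frame, rebuilt from its values.
columns = pd.MultiIndex.from_tuples([
    (1, 'A'), (1, 'B'), (1, 'C'), (1, 'D'),
    (2, 'B'), (2, 'C'), (2, 'D'),
    (3, 'B'), (3, 'D'),
])
df = pd.DataFrame(
    [[3, 1, 4, 3, 4, 2, 4, 4, 4],
     [4, 1, 4, 1, 1, 2, 4, 1, 1],
     [3, 4, 3, 2, 3, 4, 3, 3, 1],
     [2, 4, 4, 1, 4, 1, 1, 2, 3]],
    columns=columns,
)

result = keep_last_n_per_level0(df, 2)
print(result)
```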
    

    Details (computed on the original df, before filtering):

    print(df.columns.to_frame().groupby(level=0).cumcount(ascending=False))
    1  A    3
       B    2
       C    1
       D    0
    2  B    2
       C    1
       D    0
    3  B    1
       D    0
    dtype: int64
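To see what ascending=False changes, the counter can be computed on the column labels alone, with no data needed. A minimal sketch using the question's labels, comparing the default front-to-back numbering with the back-to-front numbering used above:

```python
import pandas as pd

# Column labels from the question: level-0 groups of size 4, 3 and 2.
cols = pd.MultiIndex.from_tuples([
    (1, 'A'), (1, 'B'), (1, 'C'), (1, 'D'),
    (2, 'B'), (2, 'C'), (2, 'D'),
    (3, 'B'), (3, 'D'),
])
grouped = cols.to_frame().groupby(level=0)

forward = grouped.cumcount()                  # 0 = first column of each group
backward = grouped.cumcount(ascending=False)  # 0 = last column of each group

print(forward.tolist())   # [0, 1, 2, 3, 0, 1, 2, 0, 1]
print(backward.tolist())  # [3, 2, 1, 0, 2, 1, 0, 1, 0]
```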
    
    print(df.columns.to_frame().groupby(level=0).cumcount(ascending=False) < 2)
    1  A    False
       B    False
       C     True
       D     True
    2  B    False
       C     True
       D     True
    3  B     True
       D     True
    dtype: bool
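Because the mask compares a per-group counter with n, unequal group sizes are handled automatically: a group with fewer than n columns simply keeps all of them. A sketch on the question's labels, using a hypothetical kept_columns helper, for n = 1 and n = 3:

```python
import pandas as pd

cols = pd.MultiIndex.from_tuples([
    (1, 'A'), (1, 'B'), (1, 'C'), (1, 'D'),
    (2, 'B'), (2, 'C'), (2, 'D'),
    (3, 'B'), (3, 'D'),
])
counter = cols.to_frame().groupby(level=0).cumcount(ascending=False)


def kept_columns(n):
    """Column labels selected by the mask counter < n (hypothetical helper)."""
    return [label for label, keep in zip(cols, counter < n) if keep]


print(kept_columns(1))  # last column of each group
print(kept_columns(3))  # group 3 has only 2 columns, so both are kept
```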
    

    Another idea: take the last 2 label rows per group of the column frame with GroupBy.tail, then filter the columns with Index.isin:

    df = df.loc[:, df.columns.isin(df.columns.to_frame().groupby(level=0).tail(2).index)]
    print(df)
       1     2     3   
       C  D  C  D  B  D
    0  3  1  3  2  1  1
    1  3  2  4  2  3  1
    2  2  4  3  3  1  3
    3  1  3  4  1  3  3
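A quick sketch to check that both approaches select the same columns, again on the question's printed frame rebuilt by hand, with n = 2:

```python
import pandas as pd

columns = pd.MultiIndex.from_tuples([
    (1, 'A'), (1, 'B'), (1, 'C'), (1, 'D'),
    (2, 'B'), (2, 'C'), (2, 'D'),
    (3, 'B'), (3, 'D'),
])
df = pd.DataFrame(
    [[3, 1, 4, 3, 4, 2, 4, 4, 4],
     [4, 1, 4, 1, 1, 2, 4, 1, 1],
     [3, 4, 3, 2, 3, 4, 3, 3, 1],
     [2, 4, 4, 1, 4, 1, 1, 2, 3]],
    columns=columns,
)

n = 2
frame = df.columns.to_frame()
# Approach 1: back-to-front counter per level-0 group.
by_counter = df.loc[:, (frame.groupby(level=0).cumcount(ascending=False) < n).to_numpy()]
# Approach 2: last n label rows per group, then a membership test on the columns.
by_tail = df.loc[:, df.columns.isin(frame.groupby(level=0).tail(n).index)]

print(by_counter.equals(by_tail))  # True
```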