I have a DataFrame with MultiIndex columns. For each index label (T1, T2, T3), I need to select the level-0 column group (M1, M2, M3) that has the most non-null values. I prepared an example to clarify. The expected RESULT is shown below, along with my approach so far.
import pandas as pd
import numpy as np
data = [[1,1,1,0,0,0,0,0,0],
        [1,1,1,0,0,0,0,0,0],
        [1,1,1,0,0,0,0,0,0],
        [0,0,0,2,2,2,1,1,1],
        [0,0,0,0,0,0,1,1,1],
        [0,0,0,2,2,2,1,1,1],
        [0,0,0,1,1,1,0,0,0],
        [0,0,0,1,1,1,0,0,0],
        [0,0,0,1,1,1,0,0,0]]
columns = pd.MultiIndex.from_product([['M1','M2','M3'],['x','y','z']])
index = ['T1','T1','T1','T2','T2','T2','T3','T3','T3']
df = pd.DataFrame(data, index=index, columns=columns).replace(to_replace=0, value=np.nan)
df
     M1             M2             M3
      x    y    z    x    y    z    x    y    z
T1 1.00 1.00 1.00  NaN  NaN  NaN  NaN  NaN  NaN
T1 1.00 1.00 1.00  NaN  NaN  NaN  NaN  NaN  NaN
T1 1.00 1.00 1.00  NaN  NaN  NaN  NaN  NaN  NaN
T2  NaN  NaN  NaN 2.00 2.00 2.00 1.00 1.00 1.00
T2  NaN  NaN  NaN  NaN  NaN  NaN 1.00 1.00 1.00
T2  NaN  NaN  NaN 2.00 2.00 2.00 1.00 1.00 1.00
T3  NaN  NaN  NaN 1.00 1.00 1.00  NaN  NaN  NaN
T3  NaN  NaN  NaN 1.00 1.00 1.00  NaN  NaN  NaN
T3  NaN  NaN  NaN 1.00 1.00 1.00  NaN  NaN  NaN
# Expected RESULT
#       x    y    z
# T1 1.00 1.00 1.00
# T1 1.00 1.00 1.00
# T1 1.00 1.00 1.00
# T2 1.00 1.00 1.00
# T2 1.00 1.00 1.00
# T2 1.00 1.00 1.00
# T3 1.00 1.00 1.00
# T3 1.00 1.00 1.00
# T3 1.00 1.00 1.00
# Approach
select = df.stack(level=0).count(axis=1).reset_index().groupby(['level_0','level_1']).sum().unstack(level=1).idxmax(axis=1)
# From this `select` (the correctly selected M for each T), I would like to get to the final RESULT.
select
Out[52]:
level_0
T1 (0, M1)
T2 (0, M3)
T3 (0, M2)
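IIUC, you can forward-fill across the level-0 columns and keep the last one, which then holds the last non-null value for each row and sub-column: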
# last label in the topmost column level
last_col = df.columns.get_level_values(0)[-1]  # or `df.columns[-1][0]`
out = (
    df.stack()                  # move x/y/z into the index; columns become M1, M2, M3
    .ffill(axis=1)[[last_col]]  # carry the last non-null value into the last column
    .assign(idx=lambda x: x.groupby(level=[0, 1]).cumcount())  # number the duplicate (T, x/y/z) pairs
    .reset_index()
    .pivot(index=["level_0", "idx"], columns="level_1", values=last_col)
    .droplevel(1)
    .rename_axis(index=None, columns=None)
)
Output:
print(out)
     x   y   z
T1 1.0 1.0 1.0
T1 1.0 1.0 1.0
T1 1.0 1.0 1.0
T2 1.0 1.0 1.0
T2 1.0 1.0 1.0
T2 1.0 1.0 1.0
T3 1.0 1.0 1.0
T3 1.0 1.0 1.0
T3 1.0 1.0 1.0
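Note that the `ffill` chain keeps the last non-null block in each row, which happens to coincide with the expected RESULT on this sample. If the choice should instead follow the non-null counts, as the `select` step in the question does, a minimal sketch could look like the following. It avoids `stack` and assumes the sample frame from the question; the names `per_column`, `counts` and `result` are only illustrative.

```python
import numpy as np
import pandas as pd

# Same sample frame as in the question.
data = [[1, 1, 1, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 2, 2, 2, 1, 1, 1],
        [0, 0, 0, 0, 0, 0, 1, 1, 1],
        [0, 0, 0, 2, 2, 2, 1, 1, 1],
        [0, 0, 0, 1, 1, 1, 0, 0, 0],
        [0, 0, 0, 1, 1, 1, 0, 0, 0],
        [0, 0, 0, 1, 1, 1, 0, 0, 0]]
columns = pd.MultiIndex.from_product([["M1", "M2", "M3"], ["x", "y", "z"]])
index = ["T1"] * 3 + ["T2"] * 3 + ["T3"] * 3
df = pd.DataFrame(data, index=index, columns=columns).replace(0, np.nan)

# Non-null count per index label and per (M, x/y/z) column.
per_column = df.groupby(level=0).count()

# Sum those counts over each level-0 group: rows T1..T3, columns M1..M3.
counts = per_column.T.groupby(level=0).sum().T

# Level-0 label with the most non-null values for each T
# (on a tie, idxmax returns the first label).
select = counts.idxmax(axis=1)

# Take the chosen block for each T and concatenate the pieces.
result = pd.concat([df[m].loc[[t]] for t, m in select.items()])

print(select)
print(result)
```

Here `select` holds plain labels (`M1`, `M3`, `M2`) rather than tuples, so it can be used directly to index `df`. The rows of `result` come out grouped by label in sorted order, which matches the sample; a frame whose index labels are interleaved would be reordered.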