python pandas dataframe pandas-groupby data-analysis

Filtering dataframe groups with an "or" condition


I am dealing with a dataframe such as this one:

    id  Xp_1  Xp_2  Xp_4  Xt_1  Xt_2  Xt_3  Mp_1  Mp_2  Mp_3  Mt_1  Mt_2  Mt_6
0  i24   NaN  0.27   NaN  0.45  0.20  0.25  0.27   NaN   NaN   NaN   NaN   NaN
1  i25  0.45  0.47  0.46  0.22  0.42   NaN  0.42  0.05  0.43  0.12  0.01  0.04
2  i11   NaN   NaN  0.32  0.14  0.32  0.35  0.29  0.33   NaN   NaN  0.02  0.44
3  i47   NaN  0.56  0.59  0.92   NaN  0.56  0.51  0.12   NaN  0.10  0.10   NaN

As you can see, there are two macro-groups (X and M), each with two subsets (p and t). I would like to apply an "or" condition between the two macro-groups and an "and" condition between the subsets within each macro-group.

Basically, I'd like to keep the rows that have at least two values in each subset of at least one macro-group. For example, i24 should be discarded: the Xp subset has only one value, and the M group fails as well (one Mp value and no Mt values). Entries like i11 should be kept: the condition is not satisfied for the X group, but it is satisfied for the M group. The same goes for i25, which satisfies the condition in both groups.
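
To make the rule concrete, here is a minimal sketch that rebuilds the example frame (values copied from the table above) and counts the non-missing entries in each subset. The names `values` and `counts` are only illustrative.

```python
import numpy as np
import pandas as pd

nan = np.nan
df = pd.DataFrame({
    "id":   ["i24", "i25", "i11", "i47"],
    "Xp_1": [nan, 0.45, nan, nan],
    "Xp_2": [0.27, 0.47, nan, 0.56],
    "Xp_4": [nan, 0.46, 0.32, 0.59],
    "Xt_1": [0.45, 0.22, 0.14, 0.92],
    "Xt_2": [0.20, 0.42, 0.32, nan],
    "Xt_3": [0.25, nan, 0.35, 0.56],
    "Mp_1": [0.27, 0.42, 0.29, 0.51],
    "Mp_2": [nan, 0.05, 0.33, 0.12],
    "Mp_3": [nan, 0.43, nan, nan],
    "Mt_1": [nan, 0.12, nan, 0.10],
    "Mt_2": [nan, 0.01, 0.02, 0.10],
    "Mt_6": [nan, 0.04, 0.44, nan],
})

values = df.set_index("id")

# One column per subset: how many non-missing values each row has in it
counts = pd.DataFrame({
    prefix: values.filter(regex=f"^{prefix}_").notna().sum(axis=1)
    for prefix in ["Xp", "Xt", "Mp", "Mt"]
})
print(counts)
```

A row qualifies when both X counts or both M counts are at least 2. Here only i24 fails: it has one Xp value, one Mp value and no Mt values.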

I tried this:

keep_r = (df.groupby(lambda col: col.split("_", maxsplit=1)[0], axis=1)
            .count()
            .ge(2)
            .all(axis=1))
df = df.loc[keep_r]

but this checks whether all four subsets (Xp, Xt, Mp, Mt) have at least two values. Instead, I want to treat X and M independently.

Thank you!


Solution

  • IIUC, try creating a MultiIndex from the column names with str.extract:

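    import pandas as pd
    df = df.set_index('id')
    # Split names such as 'Xp_1' into three levels: group 'X', subset 'p', number '1'
    df.columns = pd.MultiIndex.from_frame(df.columns.str.extract('(.)(.)_(.+)'))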
    
    0       X                                   M                              
    1       p                 t                 p                 t            
    2       1     2     4     1     2     3     1     2     3     1     2     6
    id                                                                         
    i24   NaN  0.27   NaN  0.45  0.20  0.25  0.27   NaN   NaN   NaN   NaN   NaN
    i25  0.45  0.47  0.46  0.22  0.42   NaN  0.42  0.05  0.43  0.12  0.01  0.04
    i11   NaN   NaN  0.32  0.14  0.32  0.35  0.29  0.33   NaN   NaN  0.02  0.44
    i47   NaN  0.56  0.59  0.92   NaN  0.56  0.51  0.12   NaN  0.10  0.10   NaN
    

    Then group by levels 0 and 1 to count the non-null values, and apply separate logic at each level (all subsets within a macro-group, any macro-group):

    keep = (
        df.groupby(axis=1, level=[0, 1]).count()
            .ge(2).all(axis=1, level=0).any(axis=1)
    )
    
    id
    i24    False
    i25     True
    i11     True
    i47     True
    dtype: bool
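
On recent pandas releases the `keep` snippet above may not run as written: as far as I know, the `level` argument of `DataFrame.all` was removed in pandas 2.0, and `groupby(axis=1)` has been deprecated since 2.1. A sketch of the same logic that avoids both by transposing first is shown here; the intermediate names (`subset_ok`, `group_ok`) are mine, and the frame is rebuilt from the question's table so the block runs on its own.

```python
import numpy as np
import pandas as pd

nan = np.nan
cols = ["Xp_1", "Xp_2", "Xp_4", "Xt_1", "Xt_2", "Xt_3",
        "Mp_1", "Mp_2", "Mp_3", "Mt_1", "Mt_2", "Mt_6"]
rows = [
    [nan, 0.27, nan, 0.45, 0.20, 0.25, 0.27, nan, nan, nan, nan, nan],
    [0.45, 0.47, 0.46, 0.22, 0.42, nan, 0.42, 0.05, 0.43, 0.12, 0.01, 0.04],
    [nan, nan, 0.32, 0.14, 0.32, 0.35, 0.29, 0.33, nan, nan, 0.02, 0.44],
    [nan, 0.56, 0.59, 0.92, nan, 0.56, 0.51, 0.12, nan, 0.10, 0.10, nan],
]
df = pd.DataFrame(rows, columns=cols,
                  index=pd.Index(["i24", "i25", "i11", "i47"], name="id"))
df.columns = pd.MultiIndex.from_frame(df.columns.str.extract("(.)(.)_(.+)"))

# After transposing, the column levels become row levels, so a plain
# row-wise groupby(level=...) replaces groupby(axis=1, level=...).
subset_ok = df.T.groupby(level=[0, 1]).count().ge(2)  # per (group, subset), per id
group_ok = subset_ok.groupby(level=0).all()           # "and" across subsets of a group
keep = group_ok.any()                                 # "or" across groups, one value per id

filtered = df.loc[keep]
```

`keep` is a boolean Series indexed by `id`, so it can be passed to `df.loc` exactly like the mask in the original snippet.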
    

    Then filter the rows and collapse the MultiIndex back to flat column names:

    df = df.loc[keep]
    df.columns = df.columns.map(lambda c: f'{"".join(c[:-1])}_{c[-1]}')
    df = df.reset_index()
    
        id  Xp_1  Xp_2  Xp_4  Xt_1  Xt_2  Xt_3  Mp_1  Mp_2  Mp_3  Mt_1  Mt_2  Mt_6
    0  i25  0.45  0.47  0.46  0.22  0.42   NaN  0.42  0.05  0.43  0.12  0.01  0.04
    1  i11   NaN   NaN  0.32  0.14  0.32  0.35  0.29  0.33   NaN   NaN  0.02  0.44
    2  i47   NaN  0.56  0.59  0.92   NaN  0.56  0.51  0.12   NaN  0.10  0.10   NaN
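
If the round trip through a MultiIndex is not needed, the same filter can be sketched directly on the flat column names. This is an alternative under stated assumptions, not the approach from the answer: `keep_mask` and its `groups` argument are hypothetical names, and the macro-groups and subset prefixes are passed in explicitly rather than parsed from the columns.

```python
import re

import numpy as np
import pandas as pd


def keep_mask(df, groups, min_count=2):
    """Return a boolean Series that is True for rows where at least one
    macro-group has `min_count` or more non-missing values in every subset.

    `groups` maps a macro-group label to its subset prefixes,
    e.g. {"X": ["Xp", "Xt"], "M": ["Mp", "Mt"]}.
    """
    group_ok = {}
    for label, prefixes in groups.items():
        subset_ok = [
            # Non-missing count over the columns named '<prefix>_...'
            df.filter(regex=f"^{re.escape(prefix)}_").notna().sum(axis=1).ge(min_count)
            for prefix in prefixes
        ]
        # "and" across the subsets of this macro-group
        group_ok[label] = pd.concat(subset_ok, axis=1).all(axis=1)
    # "or" across the macro-groups
    return pd.DataFrame(group_ok).any(axis=1)


nan = np.nan
cols = ["Xp_1", "Xp_2", "Xp_4", "Xt_1", "Xt_2", "Xt_3",
        "Mp_1", "Mp_2", "Mp_3", "Mt_1", "Mt_2", "Mt_6"]
rows = [
    [nan, 0.27, nan, 0.45, 0.20, 0.25, 0.27, nan, nan, nan, nan, nan],
    [0.45, 0.47, 0.46, 0.22, 0.42, nan, 0.42, 0.05, 0.43, 0.12, 0.01, 0.04],
    [nan, nan, 0.32, 0.14, 0.32, 0.35, 0.29, 0.33, nan, nan, 0.02, 0.44],
    [nan, 0.56, 0.59, 0.92, nan, 0.56, 0.51, 0.12, nan, 0.10, 0.10, nan],
]
df = pd.DataFrame(rows, columns=cols)
df.insert(0, "id", ["i24", "i25", "i11", "i47"])

mask = keep_mask(df, {"X": ["Xp", "Xt"], "M": ["Mp", "Mt"]})
filtered = df.loc[mask].reset_index(drop=True)
```

Because the column names are never changed, no renaming step is needed afterwards; the trade-off is that the group layout has to be spelled out by hand.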