python pandas dataframe

Groupby if at least one element is in common


I have the following data frame:

import pandas as pd
d1 = {'id': ["car", "car", "bus", "plane", "plane", "plane"], 'value': [["ab","b"], ["b","ab"], ["ab","b"], ["cd","df"], ["d","cd"], ["df","df"]]}
df = pd.DataFrame(data=d1)
df
     id      value
0   car     [ab, b]
1   car     [b, ab]
2   bus     [ab, b]
3   plane   [cd, df]
4   plane   [d, cd]
5   plane   [df, df]

I would like to group my ids if they have at least one element from the value column in common. The desired output would look like this:


     id  value
0   car [ab, b]
1   car [b, ab]
2   bus [ab, b]
      id     value
0   plane   [cd, df]
1   plane   [d, cd]
      id     value
0   plane   [cd, df]
1   plane   [df, df]

I tried using groupby, but the problem is that some ids should be included in multiple data frames, such as

plane   [cd, df]

Solution

  • You can use set operations:

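    # one frozenset of row labels per value; keep only values shared by 2+ rows
    keep = (df.explode('value').reset_index().groupby('value')['index'].agg(frozenset)
              .loc[lambda s: s.str.len()>1].unique()
           )
    
    for idx in keep:
        # pass a sorted list: set-like indexers are not reliably accepted by .loc
        print(df.loc[sorted(idx)])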
    

    Output:

        id    value
    0  car  [ab, b]
    1  car  [b, ab]
    2  bus  [ab, b]
          id     value
    3  plane  [cd, df]
    4  plane   [d, cd]
          id     value
    3  plane  [cd, df]
    5  plane  [df, df]
    

    How it works

    First, get the matching row indices for each value:

    df.explode('value').reset_index().groupby('value')['index'].agg(frozenset)
    
    value
    ab    (0, 1, 2)
    b     (0, 1, 2)
    cd       (3, 4)
    d           (4)
    df       (3, 5)
    Name: index, dtype: object
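
    To see where the index column comes from, it can help to inspect the exploded frame itself. The following is a minimal sketch that rebuilds the question's df; the names exploded and rows_per_value are introduced here only for illustration:

```python
import pandas as pd

df = pd.DataFrame({
    "id": ["car", "car", "bus", "plane", "plane", "plane"],
    "value": [["ab", "b"], ["b", "ab"], ["ab", "b"],
              ["cd", "df"], ["d", "cd"], ["df", "df"]],
})

# explode() gives every list element its own row and repeats the original
# row label; reset_index() then moves that label into a column named "index"
exploded = df.explode("value").reset_index()
print(exploded.head(4))

# For each distinct value, collect the set of original row labels containing it
rows_per_value = exploded.groupby("value")["index"].agg(frozenset)
print(rows_per_value)
```

    Because the labels are collected into a frozenset, a value that appears twice in the same row (such as df in row 5) contributes that row only once.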
    

    Remove duplicates, keep only groups of more than 1 member:

    keep = (df.explode('value').reset_index().groupby('value')['index'].agg(frozenset)
              .loc[lambda s: s.str.len()>1].unique()
           )
    
    [frozenset({0, 1, 2}), frozenset({3, 4}), frozenset({3, 5})]
    

    Finally, loop over the groups.
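
    If a single object is more convenient than printed pieces, the groups could also be stacked with pd.concat. This is only a sketch: keep is hard-coded from the result shown above, and the level names group and row are arbitrary choices, not part of the original answer.

```python
import pandas as pd

df = pd.DataFrame({
    "id": ["car", "car", "bus", "plane", "plane", "plane"],
    "value": [["ab", "b"], ["b", "ab"], ["ab", "b"],
              ["cd", "df"], ["d", "cd"], ["df", "df"]],
})

# Row-label sets taken from the previous step (hard-coded for this sketch)
keep = [frozenset({0, 1, 2}), frozenset({3, 4}), frozenset({3, 5})]

# One sub-frame per group; sorted() turns each frozenset into an ordered list
parts = [df.loc[sorted(idx)] for idx in keep]

# Stack the groups under an outer index level that numbers them
combined = pd.concat(parts, keys=list(range(len(parts))), names=["group", "row"])
print(combined)
```

    Row 3 appears under two groups in combined, which is exactly the overlap a plain groupby cannot express.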

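    Alternative syntax (same logic):

    s = df['value'].explode()
    # nunique() rather than len(): a value repeated within one row is not shared
    keep = dict.fromkeys(frozenset(x) for x in s.index.groupby(s).values() if x.nunique()>1)
    
    for idx in keep:
        # pass a sorted list: set-like indexers are not reliably accepted by .loc
        print(df.loc[sorted(idx)])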
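
    The same idea can also be written without the pandas grouping machinery, which may make the logic easier to follow. This is a sketch under stated assumptions, not the answer's own code: the function name groups_sharing_a_value and its col parameter are hypothetical.

```python
from collections import defaultdict

import pandas as pd


def groups_sharing_a_value(df, col="value"):
    """Return one sub-frame per set of rows that share a value in `col`."""
    # Map each value to the set of row labels whose list contains it
    rows_by_value = defaultdict(set)
    for label, items in df[col].items():
        for item in items:
            rows_by_value[item].add(label)

    # dict.fromkeys drops duplicate row sets while keeping first-seen order;
    # values found in a single row only are skipped
    unique_sets = dict.fromkeys(
        frozenset(rows) for rows in rows_by_value.values() if len(rows) > 1
    )
    return [df.loc[sorted(rows)] for rows in unique_sets]


df = pd.DataFrame({
    "id": ["car", "car", "bus", "plane", "plane", "plane"],
    "value": [["ab", "b"], ["b", "ab"], ["ab", "b"],
              ["cd", "df"], ["d", "cd"], ["df", "df"]],
})

for group in groups_sharing_a_value(df):
    print(group)
```

    Returning a list of frames instead of printing inside the loop keeps the grouping step reusable.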