python, pandas, dataframe, group-by

Group rows that contain at least one common element in a list column and aggregate another column


I have a DataFrame with one column that contains lists and one column of integers. I would like to group all rows whose lists have at least one element in common, and then aggregate the other column.

import pandas as pd
import networkx as nx

data = {'lot': [['6309025'],
                   ['6309025', '6375538', '6375540'],
                   ['6410558'], ['6314113']],
        'count': [1, 2, 3, 3]}

df = pd.DataFrame(data)

# keep the original row label as a node, then put one lot per row
df['id'] = df.index
df = df.explode('lot')

# graph linking each lot value to the row(s) it appears in
G = nx.from_pandas_edgelist(df, 'lot', 'id')

l = list(nx.connected_components(G))

# one {node: component number} dict per component, then merged into d
L = [dict.fromkeys(y, x) for x, y in enumerate(l)]

d = {k: v for d in L for k, v in d.items()}

s = df.groupby(df.id.map(d)).lot.apply(set)

I used the solution from this question. However, I can't find a way to aggregate the count column.
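A minimal sketch (an editorial addition, not part of the original post) of why count needs special handling: after explode, row 1's count of 2 is copied onto each of its three lot values, so a plain per-component sum gives 7 instead of 3 for the first group. Because the row labels are themselves graph nodes, one possible workaround (assumed here, not the answer's method) is to map the labels of the un-exploded rows through the node-to-component dict and sum there. The names comp, naive and per_row are illustrative.

```python
import networkx as nx
import pandas as pd

# Same input as in the question, kept un-exploded in df
df = pd.DataFrame({
    'lot': [['6309025'], ['6309025', '6375538', '6375540'],
            ['6410558'], ['6314113']],
    'count': [1, 2, 3, 3],
})

tmp = df.explode('lot')  # row label 1 now appears three times, each with count=2
G = nx.from_pandas_edgelist(tmp.reset_index(), source='lot', target='index')

# every node (lot values AND row labels) -> component number
comp = {n: i for i, c in enumerate(nx.connected_components(G)) for n in c}

# Naive: summing the exploded counts counts row 1 three times
naive = tmp.groupby(tmp['lot'].map(comp))['count'].sum()

# Workaround: group the original rows by the component of their row label
per_row = df.groupby(df.index.map(comp))['count'].sum()

print(naive.tolist())    # [7, 3, 3]
print(per_row.tolist())  # [3, 3, 3]
```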

Expected output:

    lot                                   count
 0  {6309025, 6375538, 6375540}           3
 1  {6410558}                             3
 2  {6314113}                             3

Any thoughts?


Solution

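  • You could mask the duplicated values in your exploded DataFrame, create a unique grouper per connected component, and aggregate. The masking is needed because explode copies each row's count onto every element of its list, so summing the exploded column directly would count a multi-element row several times: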

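    import networkx as nx
    import pandas as pd
    
    df = pd.DataFrame({'lot': [['6309025'], ['6309025', '6375538', '6375540'],
                               ['6410558'], ['6314113']],
                       'count': [1, 2, 3, 3]})
    
    # one row per list element; the original row label is repeated
    tmp = df.explode('lot')
    
    # bipartite graph: lot values <-> original row labels
    G = nx.from_pandas_edgelist(tmp.reset_index(), source='lot', target='index')
    S = set(tmp['lot'])
    
    # map each lot value (not the row-label nodes) to its component number
    mapper = {n: i for i, c in enumerate(nx.connected_components(G))
              for n in c if n in S}
    
    out = (tmp
           # keep the count only on the first exploded copy of each row
           .assign(count=lambda x: x['count'].mask(x.index.duplicated())
                                             .convert_dtypes())
           .groupby(tmp['lot'].map(mapper), as_index=False)
           .agg({'lot': set, 'count': 'sum'})
          )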
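A small sketch of the masking step in isolation, using a hand-built exploded frame that mirrors the question's data. Index.duplicated() flags every repeat of a row label after the first, and Series.mask replaces the flagged counts with missing values, which sum() then skips. The variable names dup and masked are illustrative.

```python
import pandas as pd

# Exploded frame: label 1 repeats because row 1 held a 3-element list
tmp = pd.DataFrame(
    {'lot': ['6309025', '6309025', '6375538', '6375540', '6410558', '6314113'],
     'count': [1, 2, 2, 2, 3, 3]},
    index=[0, 1, 1, 1, 2, 3],
)

dup = tmp.index.duplicated()  # True for each repeat of a label after the first
masked = tmp['count'].mask(dup).convert_dtypes()  # repeats -> <NA>, dtype Int64

print(dup.tolist())   # [False, False, True, True, False, False]
print(masked.sum())   # 9: the two repeated copies of row 1's count are skipped
```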
    

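    NB. To build the graph this way, the DataFrame index must be unique and its labels must not overlap with any lot value; otherwise unrelated rows can end up in the same component. If that cannot be guaranteed, build the graph with prefixed row labels instead: G = nx.from_edgelist(zip('idx:'+tmp.index.astype(str), tmp['lot'])).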
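A sketch of the collision this NB warns about, using hypothetical integer lot values chosen so that the lot value 1 equals the row label 1. With plain row labels, the two unrelated rows collapse into one component; with the prefixed labels from the NB they stay separate.

```python
import networkx as nx
import pandas as pd

# Hypothetical data: lot value 1 (row 0) collides with row label 1
df = pd.DataFrame({'lot': [[1], [7]], 'count': [5, 6]})
tmp = df.explode('lot')

# Plain labels: node 1 is both "lot 1" and "row 1", so rows 0 and 1 merge
G_naive = nx.from_pandas_edgelist(tmp.reset_index(), source='lot', target='index')

# Prefixed labels as in the NB: row nodes become 'idx:0', 'idx:1'
G_safe = nx.from_edgelist(zip('idx:' + tmp.index.astype(str), tmp['lot']))

print(nx.number_connected_components(G_naive))  # 1 (wrongly merged)
print(nx.number_connected_components(G_safe))   # 2
```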

    Output:

                               lot  count
    0  {6375540, 6309025, 6375538}      3
    1                    {6410558}      3
    2                    {6314113}      3
    

    Intermediate (group key and masked count before aggregation):

           lot  count  group  masked_count
    0  6309025      1      0             1
    1  6309025      2      0             2
    1  6375538      2      0          <NA>
    1  6375540      2      0          <NA>
    2  6410558      3      1             3
    3  6314113      3      2             3
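The group and masked_count columns in the table above are not produced by the answer's code as written. A sketch of one way to reproduce them, with the column names taken from the table and everything else as in the answer:

```python
import networkx as nx
import pandas as pd

df = pd.DataFrame({'lot': [['6309025'], ['6309025', '6375538', '6375540'],
                           ['6410558'], ['6314113']],
                   'count': [1, 2, 3, 3]})

tmp = df.explode('lot')
G = nx.from_pandas_edgelist(tmp.reset_index(), source='lot', target='index')
S = set(tmp['lot'])
mapper = {n: i for i, c in enumerate(nx.connected_components(G))
          for n in c if n in S}

# Add the grouping key and the masked count side by side for inspection
inter = tmp.assign(
    group=tmp['lot'].map(mapper),
    masked_count=tmp['count'].mask(tmp.index.duplicated()).convert_dtypes(),
)
print(inter)
```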
    

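    Variant

    You could also group the original (non-exploded) DataFrame directly: map the first item of each list to its component with mapper, then merge the lists of each group with itertools.chain. This might be more efficient, since no masking is needed and fewer rows are processed. It assumes no list is empty: for an empty list, .str[0] returns NaN and groupby silently drops that row.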

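    from itertools import chain
    
    # group the original rows by the component of their first lot value
    # (reuses mapper from above)
    out = (df.groupby(df['lot'].str[0].map(mapper), as_index=False)
             .agg({'lot': lambda x: set(chain.from_iterable(x)),  # flatten the lists
                   'count': 'sum'
                  })
          )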
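If networkx is not available, the same grouping can be sketched with a small union-find (disjoint-set) structure over the rows, merging two rows whenever they share a lot value. This is a swapped-in technique, not the answer's method, and the helper name components_by_shared_element is hypothetical.

```python
import pandas as pd

df = pd.DataFrame({'lot': [['6309025'], ['6309025', '6375538', '6375540'],
                           ['6410558'], ['6314113']],
                   'count': [1, 2, 3, 3]})


def components_by_shared_element(lists):
    """Return a component id per list; lists sharing any element get the same id."""
    parent = list(range(len(lists)))

    def find(i):
        # Path halving keeps the trees shallow
        while parent[i] != i:
            parent[i] = parent[parent[i]]
            i = parent[i]
        return i

    first_owner = {}  # element -> position of the first list that contained it
    for pos, items in enumerate(lists):
        for item in items:
            if item in first_owner:
                # Union this list's set with the set of the element's first owner
                parent[find(pos)] = find(first_owner[item])
            else:
                first_owner[item] = pos

    # Relabel the roots as 0, 1, 2, ... in order of first appearance
    labels = {}
    return [labels.setdefault(find(pos), len(labels)) for pos in range(len(lists))]


group = components_by_shared_element(df['lot'].tolist())
key = pd.Series(group, index=df.index)

out = (df.groupby(key)
         .agg(lot=('lot', lambda s: set().union(*s)), count=('count', 'sum'))
         .reset_index(drop=True))
print(out)
```

Each row is visited once and each union is nearly constant time, so this scales roughly linearly in the total number of list elements.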
    

    [Figure: the graph G linking lot values to row labels]