I have a DataFrame with one column that contains lists and one column of integers. I would like to group all lists that share at least one element, and then aggregate the other column.
import pandas as pd
import networkx as nx

data = {'lot': [['6309025'],
                ['6309025', '6375538', '6375540'],
                ['6410558'], ['6314113']],
        'count': [1, 2, 3, 3]}
df = pd.DataFrame(data)
df['id'] = df.index        # remember the original row of each list element
df = df.explode('lot')     # one row per list element
# bipartite graph linking lot values to row ids
G = nx.from_pandas_edgelist(df, 'lot', 'id')
l = list(nx.connected_components(G))
# map every node (lot or row id) to the number of its component
L = [dict.fromkeys(y, x) for x, y in enumerate(l)]
d = {k: v for d in L for k, v in d.items()}
s = df.groupby(df.id.map(d)).lot.apply(set)
I used the solution from this question. However, I can't find a way to aggregate the count column.
Expected output:
lot count
0 {6309025, 6375538, 6375540} 3
1 {6410558} 3
2 {6314113} 3
Any thoughts?
You could mask the repeated count values in the exploded DataFrame (so that each original row is counted only once), map each lot to its connected component to get a grouper, and aggregate:
import networkx as nx

tmp = df.explode('lot')
# bipartite graph linking lot values to the original row labels
G = nx.from_pandas_edgelist(tmp.reset_index(), source='lot', target='index')
# only lot nodes are needed in the mapper
S = set(tmp['lot'])
mapper = {n: i for i, c in enumerate(nx.connected_components(G))
          for n in c if n in S}

out = (tmp
       # keep each original count once: blank it out on repeated row labels
       .assign(count=lambda x: x['count'].mask(x.index.duplicated())
                                         .convert_dtypes())
       .groupby(tmp['lot'].map(mapper), as_index=False)
       .agg({'lot': set, 'count': 'sum'})
      )
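To see why the count is masked before summing, here is a minimal check on the question's data (it does not need the graph). explode() copies count onto every element of a list, so a plain sum over-counts the three-element row, while masking the repeated row labels restores the original total:

```python
import pandas as pd

df = pd.DataFrame({'lot': [['6309025'], ['6309025', '6375538', '6375540'],
                           ['6410558'], ['6314113']],
                   'count': [1, 2, 3, 3]})
tmp = df.explode('lot')

# explode() repeats each row's count once per list element: 1, 2, 2, 2, 3, 3
naive_total = tmp['count'].sum()

# keep the count only on the first exploded row of each original row
masked = tmp['count'].mask(tmp.index.duplicated())
masked_total = masked.sum()  # NaN entries are skipped by sum()

print(naive_total, masked_total, df['count'].sum())
```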
NB. To build the graph, the original row labels must be unique and must not overlap with the lot values. If they could overlap, build the graph with prefixed row labels instead: G = nx.from_edgelist(zip('idx:'+tmp.index.astype(str), tmp['lot'])).
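A small sketch of the overlap problem, using hypothetical data where the lot values are integers that coincide with the row labels. Without the prefix, row label 2 and lot value 2 become the same node, so the unrelated third row is merged into the first component:

```python
import networkx as nx
import pandas as pd

# hypothetical data: lot value 2 collides with row label 2
df = pd.DataFrame({'lot': [[2], [2, 3], [7]], 'count': [1, 2, 5]})
tmp = df.explode('lot')

# naive graph: row labels and lot values share one node namespace
G_naive = nx.from_edgelist(zip(tmp.index, tmp['lot']))
# prefixed row labels keep the two namespaces apart
G_safe = nx.from_edgelist(zip('idx:' + tmp.index.astype(str), tmp['lot']))

print(nx.number_connected_components(G_naive))  # 1
print(nx.number_connected_components(G_safe))   # 2
```

With the prefixed graph, the mapper built with `if n in S` still keeps only the lot nodes, since the 'idx:...' labels are not lot values.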
Output:
lot count
0 {6375540, 6309025, 6375538} 3
1 {6410558} 3
2 {6314113} 3
Intermediate (the exploded rows with the component number and the masked count):
lot count group masked_count
0 6309025 1 0 1
1 6309025 2 0 2
1 6375538 2 0 <NA>
1 6375540 2 0 <NA>
2 6410558 3 1 3
3 6314113 3 2 3
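The intermediate table can be reproduced by assigning the grouper and the masked count as columns instead of grouping right away. This is a sketch on the question's data; the column names group and masked_count are only the labels used in the table above:

```python
import networkx as nx
import pandas as pd

df = pd.DataFrame({'lot': [['6309025'], ['6309025', '6375538', '6375540'],
                           ['6410558'], ['6314113']],
                   'count': [1, 2, 3, 3]})
tmp = df.explode('lot')

G = nx.from_pandas_edgelist(tmp.reset_index(), source='lot', target='index')
S = set(tmp['lot'])
mapper = {n: i for i, c in enumerate(nx.connected_components(G))
          for n in c if n in S}

inter = tmp.assign(
    group=tmp['lot'].map(mapper),  # component number of each lot
    # count kept only on the first exploded row of each original row
    masked_count=tmp['count'].mask(tmp.index.duplicated()).convert_dtypes(),
)
print(inter)
```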
You could also group the original (non-exploded) DataFrame by mapping the first item of each list to its component, and merge the lists with itertools.chain. This might be more efficient, as you won't need to mask and will process fewer rows.
from itertools import chain
out = (df.groupby(df['lot'].str[0].map(mapper), as_index=False)
.agg({'lot': lambda x: set(chain.from_iterable(x)),
'count': 'sum'
})
)
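Along the same lines, the question's own mapping d already contains the row ids, because they are nodes of the graph too, so it can group the unexploded DataFrame directly without str[0]. A sketch under that assumption: orig is a hypothetical name for the unexploded frame (the question overwrites df), and d is built with a single dict comprehension equivalent to the question's L/d two-step:

```python
from itertools import chain

import networkx as nx
import pandas as pd

data = {'lot': [['6309025'], ['6309025', '6375538', '6375540'],
                ['6410558'], ['6314113']],
        'count': [1, 2, 3, 3]}
orig = pd.DataFrame(data)  # one row per original list

# the question's graph: lot values linked to row ids
tmp = orig.assign(id=orig.index).explode('lot')
G = nx.from_pandas_edgelist(tmp, 'lot', 'id')
# every node, lot or row id, mapped to its component number
d = {n: i for i, c in enumerate(nx.connected_components(G)) for n in c}

# row ids are keys of d, so each original row maps to its component
out = (orig.groupby(orig.index.map(d))
           .agg(lot=('lot', lambda x: set(chain.from_iterable(x))),
                count=('count', 'sum'))
           .reset_index(drop=True))
print(out)
```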