Imagine I have a dictionary with string keys and integer values, like:
{'255': 8,
'323': 1,
'438': 1,
'938': 2,
'166': 1,
'117': 10,
'777': 2
}
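I would like to create a new dictionary that groups together keys sharing at least one character, either directly or through a chain of other keys, and maps each group to the sum of its keys' values.
For the example above, the result would look something like: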
{
'group1' : 12, # Sum of key values for '255' (8), '323' (1), '438' (1), '938' (2)
'group2' : 13 # Sum of key values for '166' (1), '117' (10), '777' (2)
}
To clarify, group1 is created as follows: 255 has a 2, which matches the 2 in 323. Then the 3s in 323 match the 3 in 438. Finally, the 8 (or the 3) in 438 matches the 8 (or 3) in 938. Adding the values for these keys gives 8 + 1 + 1 + 2 = 12.
None of these keys (group1) are merged with the remaining keys (166, 117 and 777, which make up group2), because none of the characters in the group1 keys match any of the characters in the group2 keys.
group2 is created in the same way, by matching the 1 in 166 to the 1s in 117, and the 7 in 117 to the 7s in 777.
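To make the chaining concrete, here is a small sketch that lists every pair of keys in the example dictionary that share at least one character. The name `overlaps` is just illustrative. Note that 255 and 438 share nothing directly; they end up in the same group only because both overlap with 323, and no pair links a group1 key to a group2 key.

```python
from itertools import combinations

data = {'255': 8, '323': 1, '438': 1, '938': 2,
        '166': 1, '117': 10, '777': 2}

# For every unordered pair of keys, record the characters they have in common.
overlaps = {}
for a, b in combinations(data, 2):
    shared = set(a) & set(b)
    if shared:
        overlaps[(a, b)] = shared

for (a, b), shared in overlaps.items():
    print(a, b, sorted(shared))
```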
Final note: the names group1 and group2 are only placeholders, used for convenience.

What is an efficient way to accomplish this?
I thought about converting the keys into a numpy matrix where each row represents a single key and each column represents one of its characters. However, this approach gets complicated quickly, and I'd rather avoid it if there is a better way.
You can use a union-find (disjoint-set) data structure to solve this problem. The networkx package provides an implementation, but there's nothing stopping you from writing your own.
In essence, we maintain a collection of disjoint sets. Initially, every string belongs to its own disjoint set. For each pair of strings that have at least one character in common, we union the disjoint sets they belong to. Once every pair has been checked, each remaining disjoint set is one of the groups we're looking for.
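If you'd rather not depend on networkx, a hand-written union-find is short. The sketch below uses only the standard library; `DisjointSet` and `group_sums` are hypothetical names, and the path-halving and union-by-size optimizations are common union-find techniques, not a description of how the networkx class is implemented.

```python
class DisjointSet:
    """Minimal union-find with path halving and union by size."""

    def __init__(self, items):
        self.parent = {x: x for x in items}
        self.size = {x: 1 for x in items}

    def find(self, x):
        # Climb to the root, pointing each visited node at its grandparent.
        while self.parent[x] != x:
            self.parent[x] = self.parent[self.parent[x]]
            x = self.parent[x]
        return x

    def union(self, a, b):
        ra, rb = self.find(a), self.find(b)
        if ra == rb:
            return
        # Attach the smaller tree under the larger one.
        if self.size[ra] < self.size[rb]:
            ra, rb = rb, ra
        self.parent[rb] = ra
        self.size[ra] += self.size[rb]


def group_sums(data):
    keys = list(data)
    ds = DisjointSet(keys)
    # Union every pair of keys that share at least one character.
    for i in range(len(keys)):
        for j in range(i + 1, len(keys)):
            if set(keys[i]) & set(keys[j]):
                ds.union(keys[i], keys[j])
    # Sum the values of all keys that share a root.
    sums = {}
    for key, value in data.items():
        root = ds.find(key)
        sums[root] = sums.get(root, 0) + value
    return sums


data = {'255': 8, '323': 1, '438': 1, '938': 2,
        '166': 1, '117': 10, '777': 2}
print(group_sums(data))  # {'255': 12, '166': 13}
```

This version labels each group by its root key rather than by 0, 1, and so on; swap in `enumerate` over `sums.values()` if you prefer numeric labels.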
From there, we use the .to_sets() method to get the groupings and compute the desired sums:
```python
from networkx.utils.union_find import UnionFind

# The dictionary from the question.
data = {'255': 8, '323': 1, '438': 1, '938': 2,
        '166': 1, '117': 10, '777': 2}

keys = list(data.keys())
uf = UnionFind(keys)

# Union every pair of keys that share at least one character.
for outer_idx in range(len(keys)):
    for inner_idx in range(outer_idx + 1, len(keys)):
        if set(keys[outer_idx]) & set(keys[inner_idx]):
            uf.union(keys[outer_idx], keys[inner_idx])

# Each disjoint set is one group; sum the values of its keys.
result = {}
for idx, group in enumerate(uf.to_sets()):
    result[idx] = sum(data[key] for key in group)
print(result)
```
This outputs:
{0: 12, 1: 13}
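The nested loop above compares every pair of keys, which is quadratic in the number of keys. If the dictionary is large, one possible variant (a sketch, not part of the approach above) avoids pairwise comparisons: remember the first key seen for each character and union every later key containing that character with it. Keys that share a character then always end up in the same set. `group_sums_by_char` and `first_key_with` are hypothetical names, and a small dict-based union-find stands in for networkx to keep the example self-contained.

```python
def group_sums_by_char(data):
    parent = {key: key for key in data}

    def find(x):
        # Path halving: point nodes at their grandparent while climbing.
        while parent[x] != x:
            parent[x] = parent[parent[x]]
            x = parent[x]
        return x

    first_key_with = {}  # character -> first key seen that contains it
    for key in data:
        for ch in set(key):
            if ch in first_key_with:
                # Merge this key's set into the set that owns this character.
                parent[find(key)] = find(first_key_with[ch])
            else:
                first_key_with[ch] = key

    # Sum the values of all keys that share a root.
    sums = {}
    for key, value in data.items():
        root = find(key)
        sums[root] = sums.get(root, 0) + value
    return sums


data = {'255': 8, '323': 1, '438': 1, '938': 2,
        '166': 1, '117': 10, '777': 2}
print(group_sums_by_char(data))  # {'255': 12, '166': 13}
```

This does one pass over the characters of all keys instead of comparing every pair, at the cost of an extra dictionary indexed by character.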