Suppose I have a custom data structure Data that reveals two relevant properties: tag indicates which equivalence class this item belongs in, and rank indicates how good this item is.
I have an unordered set of Data objects, and want to retrieve the n objects with the highest rank, but with at most one object from each equivalence class.
(Objects in the same equivalence class don't necessarily compare equal, and don't necessarily have the same rank, but I don't want any two elements in my output to come from the same class. In other words, the relation that produces these equivalence classes isn't ==.)
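My first approach looks something like this:
1. Sort the list by rank (descending).
2. Start with an empty set s.
3. For each element in the list:
   - Check whether its tag is in s; if so, move on.
   - Otherwise, add its tag to s and add the element to the output.
   - Once the output has n elements, stop.
However, this feels awkward, like there should be some better way (potentially using itertools and higher-order functions). The order of the resulting n elements isn't important.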
Toy example:
from collections import namedtuple
Data = namedtuple('Data', ('tag', 'rank'))
n = 3
algorithm_input = { Data('a', 200), Data('a', 100), Data('b', 50), Data('c', 10), Data('d', 5) }
expected_output = { Data('a', 200), Data('b', 50), Data('c', 10) }
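For reference, the loop-based first approach could look roughly like this on the toy example. This is a minimal sketch, not the asker's actual code: the function name `top_n_distinct` is hypothetical, and it assumes `tag` values are hashable so they can be stored in a set.

```python
from collections import namedtuple

Data = namedtuple('Data', ('tag', 'rank'))

def top_n_distinct(items, n):
    """Hypothetical helper: up to n highest-ranked items, at most one per tag."""
    seen_tags = set()
    out = []
    # Visit items from highest to lowest rank
    for item in sorted(items, key=lambda d: d.rank, reverse=True):
        if len(out) >= n:
            break  # we already have n items, stop
        if item.tag in seen_tags:
            continue  # this class is already represented by a higher-ranked item
        seen_tags.add(item.tag)
        out.append(item)
    return out

algorithm_input = {Data('a', 200), Data('a', 100), Data('b', 50), Data('c', 10), Data('d', 5)}
print(top_n_distinct(algorithm_input, 3))
```

Checking `len(out) >= n` before handling each item also makes `n = 0` return an empty list.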
You could use itertools.groupby (see the itertools documentation). Because groupby only groups consecutive items, first sort the items by tag (and by descending rank within each tag), then group them by tag and keep only the first, highest-ranked item from each group. Finally, pick the n best of those representatives with heapq.nlargest:
from itertools import groupby
from collections import namedtuple
import heapq
Data = namedtuple('Data', ('tag', 'rank'))
n = 3
algorithm_input = { Data('a', 200), Data('a', 100), Data('b', 50), Data('c', 10), Data('d', 5) }
# 1. sort by tag (ascending) so items of the same class are adjacent, and by rank (descending) within each tag
s = sorted(algorithm_input, key=lambda k: (k.tag, -k.rank))
# 2. group by tag and keep the first (highest-ranked) item from each group
best_per_tag = (next(g) for _, g in groupby(s, key=lambda k: k.tag))
# 3. keep the n items with the highest rank among those representatives
out = heapq.nlargest(n, best_per_tag, key=lambda k: k.rank)
print(out)
Prints:
[Data(tag='a', rank=200), Data(tag='b', rank=50), Data(tag='c', rank=10)]
EDIT: Changed the sorting key.
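The effect of the sort key on groupby is easy to check with a small, hypothetical input in which two `'a'` items are separated by a `'b'` item when ordered by rank:

```python
from itertools import groupby
from collections import namedtuple

Data = namedtuple('Data', ('tag', 'rank'))

items = [Data('a', 200), Data('b', 150), Data('a', 100)]

# Ordered by rank, the two 'a' items are not adjacent, so groupby yields 'a' twice
by_rank = sorted(items, key=lambda k: -k.rank)
print([tag for tag, _ in groupby(by_rank, key=lambda k: k.tag)])  # ['a', 'b', 'a']

# Ordered by tag first, each tag forms exactly one group
by_tag = sorted(items, key=lambda k: (k.tag, -k.rank))
print([tag for tag, _ in groupby(by_tag, key=lambda k: k.tag)])  # ['a', 'b']
```

Taking the first item of every group from `by_rank` would therefore return two items with tag `'a'`.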