Tags: python, algorithm, sorting, equivalence-classes

Sorting with equivalence classes in Python


Suppose I have a custom data structure Data that exposes two relevant properties: tag indicates which equivalence class the item belongs to, and rank indicates how good the item is.

I have an unordered set of Data objects and want to retrieve the n objects with the highest rank, but with at most one object from each equivalence class.

(Objects in the same equivalence class don't necessarily compare equal, and don't necessarily have the same rank, but I don't want any two elements in my output to come from the same class. In other words, the relation that produces these equivalence classes isn't ==.)

My first approach looks something like this:

  • Sort the list by descending rank
  • Create an empty set s
  • For each element in the list:
    • Check if its tag is in s; if so, move on
    • Add its tag to s
    • Yield that element
    • If we've yielded n elements, stop
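For reference, the steps above can be sketched as a generator. This is only a minimal sketch: the function name top_n_distinct is hypothetical, and Data is assumed to be the namedtuple from the toy example further down.

```python
from collections import namedtuple

# Assumed shape of the items, matching the toy example in the question.
Data = namedtuple('Data', ('tag', 'rank'))


def top_n_distinct(items, n):
    """Yield up to n items in descending rank order, at most one per tag.

    Hypothetical helper name; it follows the bullet list step by step.
    """
    if n <= 0:
        return
    seen = set()  # tags that have already been yielded
    for item in sorted(items, key=lambda d: d.rank, reverse=True):
        if item.tag in seen:
            continue  # a higher-ranked item of this class was already yielded
        seen.add(item.tag)
        yield item
        if len(seen) == n:
            return  # n distinct classes collected


algorithm_input = {Data('a', 200), Data('a', 100), Data('b', 50), Data('c', 10), Data('d', 5)}
result = set(top_n_distinct(algorithm_input, 3))
```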

However, this feels awkward, like there should be some better way (potentially using itertools and higher-order functions). The order of the resulting n elements isn't important.

What's the Pythonic solution to this problem?

Toy example:

from collections import namedtuple

Data = namedtuple('Data', ('tag', 'rank'))
n = 3

algorithm_input = { Data('a', 200), Data('a', 100), Data('b', 50), Data('c', 10), Data('d', 5) }
expected_output = { Data('a', 200), Data('b', 50), Data('c', 10) }

Solution

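  • You could use itertools.groupby. Because groupby only groups consecutive items, sorting by rank alone is not enough: two items with the same tag can be separated by a differently tagged item whose rank lies between theirs, and the tag would then appear twice in the output. Instead, sort the items by tag (highest rank first within each tag), keep the first item of each group, and then take the n highest-ranked of those:

    from collections import namedtuple
    from heapq import nlargest
    from itertools import groupby
    
    Data = namedtuple('Data', ('tag', 'rank'))
    
    n = 3
    
    algorithm_input = { Data('a', 200), Data('a', 100), Data('b', 50), Data('c', 10), Data('d', 5) }
    
    # 1. sort by tag so that each equivalence class is contiguous,
    #    with the highest rank first inside each class
    s = sorted(algorithm_input, key=lambda k: (k.tag, -k.rank))
    
    # 2. keep the first (highest-ranked) item of each tag group
    best = [next(g) for _, g in groupby(s, key=lambda k: k.tag)]
    
    # 3. take the n highest-ranked representatives
    out = nlargest(n, best, key=lambda k: k.rank)
    
    print(out)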
    

    Prints:

    [Data(tag='a', rank=200), Data(tag='b', rank=50), Data(tag='c', rank=10)]
    

    EDIT: Changed the sorting key.
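
If sorting the whole input is not needed, one alternative is a single pass that remembers the best item per tag in a dict, followed by heapq.nlargest over those representatives. This is a sketch under assumptions rather than part of the answer above: the function name top_n_per_class is hypothetical, and tags are assumed to be hashable.

```python
from collections import namedtuple
from heapq import nlargest

Data = namedtuple('Data', ('tag', 'rank'))


def top_n_per_class(items, n):
    """Return up to n items with the highest rank, at most one per tag.

    Hypothetical helper; assumes tags are hashable.
    """
    best = {}  # tag -> highest-ranked item seen so far for that tag
    for item in items:
        current = best.get(item.tag)
        if current is None or item.rank > current.rank:
            best[item.tag] = item
    # Pick the n highest-ranked representatives, in descending rank order.
    return nlargest(n, best.values(), key=lambda d: d.rank)


algorithm_input = {Data('a', 200), Data('a', 100), Data('b', 50), Data('c', 10), Data('d', 5)}
out = top_n_per_class(algorithm_input, 3)
```

Unlike sorting by tag, this does not require tags to be orderable, only hashable.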