Tags: python, python-polars

How do I select the top k rows of a python polars dataframe for each group?


Polars DataFrames have a top_k method that selects the rows containing the k largest values of a given column. For example, the following code selects the two rows with the largest and second-largest entries in the val column:

import polars as pl

df = pl.DataFrame({'grp': ['a', 'a', 'a', 'b', 'b', 'b'], 'val': [1, 2, 3, 10, 20, 30], 'etc': [0, 1, 2, 3, 4, 5]})

grp val etc
str i64 i64
"a" 1   0
"a" 2   1
"a" 3   2
"b" 10  3
"b" 20  4
"b" 30  5

df.top_k(2, by='val')

grp val etc
str i64 i64
"b" 30  5
"b" 20  4

My question is: how do I get the rows with the top k values for each group? Specifically, I want the entire row, not just the value in the val column. I want to do something like the following, but it doesn't work because polars GroupBy doesn't have a top_k method:

df.group_by('grp').top_k(2, by='val')  # doesn't work: GroupBy has no top_k method

grp val etc
str i64 i64
"b" 30  5
"b" 20  4
"a" 3   2
"a" 2   1

I was able to come up with two approaches: one using map_groups and another using sorting. Neither is desirable for performance reasons. map_groups is generally not recommended because it calls back into Python for each group and is almost always significantly slower. Sorting is also undesirable because selecting the top k elements admits a faster algorithm than a full sort (for small k and large n, it's roughly O(n) versus O(n log n)). So even though the approaches below work, I'm looking for alternatives. Is there any way to use a top_k method directly with a polars group_by? That would be my ideal solution.

# works, but at the expense of using map_groups
df.group_by('grp').map_groups(lambda df: df.top_k(2, by='val'))

grp val etc
str i64 i64
"b" 30  5
"b" 20  4
"a" 3   2
"a" 2   1

# works, but at the expense of sorting entire groups
df.group_by('grp').agg(pl.all().sort_by('val', descending=True).head(2)).explode('val', 'etc')

grp val etc
str i64 i64
"a" 3   2
"a" 2   1
"b" 30  5
"b" 20  4
To summarise, the three approaches are:

  • df.group_by('grp').top_k(2, by='val'), which doesn't work in polars
  • df.group_by('grp').map_groups(lambda df: df.top_k(2, by='val')), which works at the cost of using map_groups
  • df.group_by('grp').agg(pl.all().sort_by('val', descending=True).head(2)).explode('val', 'etc'), which works at the cost of sorting

Solution

  • The latest release of polars (version 0.20.24) introduced pl.Expr.top_k_by (and also pl.Expr.bottom_k_by, sketched at the end of this answer) with an optimised runtime complexity of O(n + k log n - k/2), for precisely the use case mentioned in the question.

    It can be used jointly with pl.Expr.over and mapping_strategy="explode" to obtain the desired result.

    df.select(
        pl.all().top_k_by("val", k=2).over("grp", mapping_strategy="explode")
    )
    
    shape: (4, 3)
    ┌─────┬─────┬─────┐
    │ grp ┆ val ┆ etc │
    │ --- ┆ --- ┆ --- │
    │ str ┆ i64 ┆ i64 │
    ╞═════╪═════╪═════╡
    │ a   ┆ 3   ┆ 2   │
    │ a   ┆ 2   ┆ 1   │
    │ b   ┆ 30  ┆ 5   │
    │ b   ┆ 20  ┆ 4   │
    └─────┴─────┴─────┘
    

    Note. The call to pl.Expr.over with mapping_strategy="explode" is equivalent to the following aggregation.

    df.group_by("grp").agg(pl.all().top_k_by("val", k=2)).explode(pl.exclude("grp"))