The polars DataFrame has a top_k method that selects the k rows containing the largest values in a given column. For example, the following code selects the two rows with the largest and second-largest entries in the val column:
df = pl.DataFrame({'grp':['a','a','a','b','b','b'], 'val':[1,2,3,10,20,30], 'etc':[0,1,2,3,4,5]})
grp val etc
str i64 i64
"a" 1 0
"a" 2 1
"a" 3 2
"b" 10 3
"b" 20 4
"b" 30 5
df.top_k(2, by='val')
grp val etc
str i64 i64
"b" 30 5
"b" 20 4
My question is: how do I get the rows with the top k values for each group? Specifically, I want the entire row, not just the value in the val column. I want to do something like the following, but it doesn't work because polars GroupBy doesn't have a top_k method:
df.group_by('grp').top_k(2, by='val')  # doesn't work in polars
grp val etc
str i64 i64
"b" 30 5
"b" 20 4
"a" 3 2
"a" 2 1
I was able to come up with two ways: one using map_groups and another using sorting. Neither is desirable for performance reasons. map_groups is generally not recommended because it is almost always significantly slower than native expressions. Sorting is also undesirable because selecting the top k elements admits a faster algorithm than a full sort (for small k and large n, it's roughly O(n) vs O(n log n)). So even though the approaches below work, I'm looking for alternatives. Is there any way to use a top_k method directly with a polars group_by? That would be my ideal solution.
# works, but at the expense of using map_groups
df.group_by('grp').map_groups(lambda df: df.top_k(2, by='val'))
grp val etc
str i64 i64
"b" 30 5
"b" 20 4
"a" 3 2
"a" 2 1
# works, but at the expense of sorting entire groups
df.group_by('grp').agg(pl.all().sort_by('val', descending=True).head(2)).explode('val','etc')
grp val etc
str i64 i64
"a" 3 2
"a" 2 1
"b" 30 5
"b" 20 4
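For contrast, the complexity argument can be illustrated in plain Python. This is only a sketch of the idea, not how polars implements it: the standard library's heapq.nlargest maintains a k-element heap, so each group is processed in roughly O(n log k) rather than the O(n log n) of sorting the whole group.

```python
import heapq
from collections import defaultdict

# The same data as in the question, as a list of row dicts
rows = [
    {"grp": "a", "val": 1, "etc": 0},
    {"grp": "a", "val": 2, "etc": 1},
    {"grp": "a", "val": 3, "etc": 2},
    {"grp": "b", "val": 10, "etc": 3},
    {"grp": "b", "val": 20, "etc": 4},
    {"grp": "b", "val": 30, "etc": 5},
]

# Bucket the rows by group
groups = defaultdict(list)
for row in rows:
    groups[row["grp"]].append(row)

# Take the k largest rows by 'val' from each bucket, without a full sort
k = 2
top = {g: heapq.nlargest(k, rs, key=lambda r: r["val"]) for g, rs in groups.items()}
# top["b"] holds the two rows with the largest 'val' in group "b", in descending order
```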
In summary:
df.group_by('grp').top_k(2, by='val'), which doesn't work in polars
df.group_by('grp').map_groups(lambda df: df.top_k(2, by='val')), which works at the cost of using map_groups
df.group_by('grp').agg(pl.all().sort_by('val', descending=True).head(2)).explode('val','etc'), which works at the cost of sorting

The latest release (version 0.20.24) of polars introduced pl.Expr.top_k_by (and also pl.Expr.bottom_k_by) with optimised runtime complexity O(n + k log n - k / 2) for precisely the use case mentioned in the question.
It can be used jointly with pl.Expr.over and mapping_strategy="explode" to obtain the desired result.
df.select(
pl.all().top_k_by("val", k=2).over("grp", mapping_strategy="explode")
)
shape: (4, 3)
┌─────┬─────┬─────┐
│ grp ┆ val ┆ etc │
│ --- ┆ --- ┆ --- │
│ str ┆ i64 ┆ i64 │
╞═════╪═════╪═════╡
│ a ┆ 3 ┆ 2 │
│ a ┆ 2 ┆ 1 │
│ b ┆ 30 ┆ 5 │
│ b ┆ 20 ┆ 4 │
└─────┴─────┴─────┘
Note: the call to pl.Expr.over with mapping_strategy="explode" is equivalent to the following aggregation.
df.group_by("grp").agg(pl.all().top_k_by("val", k=2)).explode(pl.exclude("grp"))