
Polars get grouped rows where column value is maximum


Consider this snippet:

import polars as pl

df = pl.DataFrame({'class': ['a', 'a', 'b', 'b'], 'name': ['Ron', 'Jon', 'Don', 'Von'], 'score': [0.2, 0.5, 0.3, 0.4]})
df.group_by('class').agg(pl.col('score').max())

This gives me:

shape: (2, 2)
┌───────┬───────┐
│ class ┆ score │
│ ---   ┆ ---   │
│ str   ┆ f64   │
╞═══════╪═══════╡
│ a     ┆ 0.5   │
│ b     ┆ 0.4   │
└───────┴───────┘

But I want the entire row in each group that has the maximum score. I can get it by joining the result back to the original DataFrame:

sdf = df.group_by('class').agg(pl.col('score').max())
sdf.join(df, on=['class', 'score'])

This gives:

shape: (2, 3)
┌───────┬───────┬──────┐
│ class ┆ score ┆ name │
│ ---   ┆ ---   ┆ ---  │
│ str   ┆ f64   ┆ str  │
╞═══════╪═══════╪══════╡
│ a     ┆ 0.5   ┆ Jon  │
│ b     ┆ 0.4   ┆ Von  │
└───────┴───────┴──────┘

Is there any way to avoid the join and include the name column as part of the groupby aggregation?


Solution

  • You can use a sort_by expression to sort the rows in each group by score, and then use the last expression to take the last row. Because sort_by sorts in ascending order by default, the last row in each group is the one with the maximum score.

    For example, to take all columns:

    df.group_by('class').agg(
        pl.all().sort_by('score').last(),
    )
    
    shape: (2, 3)
    ┌───────┬──────┬───────┐
    │ class ┆ name ┆ score │
    │ ---   ┆ ---  ┆ ---   │
    │ str   ┆ str  ┆ f64   │
    ╞═══════╪══════╪═══════╡
    │ a     ┆ Jon  ┆ 0.5   │
    │ b     ┆ Von  ┆ 0.4   │
    └───────┴──────┴───────┘
    

    Edit: using over

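    If more than one row in a group shares the maximum score, sort_by followed by last still returns only one of them. An easy way to get all of the tied rows is to filter with a window expression using over.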

    For example, if your data has two students in class b ('Von' and 'Yvonne') who tied for highest score:

    df = pl.DataFrame(
        {
            "class": ["a", "a", "b", "b", "b"],
            "name": ["Ron", "Jon", "Don", "Von", "Yvonne"],
            "score": [0.2, 0.5, 0.3, 0.4, 0.4],
        }
    )
    df
    
    shape: (5, 3)
    ┌───────┬────────┬───────┐
    │ class ┆ name   ┆ score │
    │ ---   ┆ ---    ┆ ---   │
    │ str   ┆ str    ┆ f64   │
    ╞═══════╪════════╪═══════╡
    │ a     ┆ Ron    ┆ 0.2   │
    │ a     ┆ Jon    ┆ 0.5   │
    │ b     ┆ Don    ┆ 0.3   │
    │ b     ┆ Von    ┆ 0.4   │
    │ b     ┆ Yvonne ┆ 0.4   │
    └───────┴────────┴───────┘
    
    df.filter(pl.col('score') == pl.col('score').max().over('class'))
    
    shape: (3, 3)
    ┌───────┬────────┬───────┐
    │ class ┆ name   ┆ score │
    │ ---   ┆ ---    ┆ ---   │
    │ str   ┆ str    ┆ f64   │
    ╞═══════╪════════╪═══════╡
    │ a     ┆ Jon    ┆ 0.5   │
    │ b     ┆ Von    ┆ 0.4   │
    │ b     ┆ Yvonne ┆ 0.4   │
    └───────┴────────┴───────┘