What is the Polars equivalent of Pandas idxmin in a groupby aggregate?


I'm looking for a Polars equivalent of Pandas idxmin in a group_by/agg operation.

With Polars and this example dataframe:

import polars as pl

dct = {
    "a": [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2],
    "b": [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5],
    "c": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
    "d": [0.18418279, 0.67394382, 0.80951643, 0.10115085, 0.03998497, 0.17771175,
          0.28428486, 0.24553192, 0.31388881, 0.07525366, 0.28622033, 0.61240989]
}

df = pl.DataFrame(dct)
a   b   c   d
i64 i64 i64 f64
0   0   0   0.184183
0   0   1   0.673944
0   1   2   0.809516
0   1   3   0.101151
1   2   4   0.039985
1   2   5   0.177712
1   3   6   0.284285
1   3   7   0.245532
2   4   8   0.313889
2   4   9   0.075254
2   5   10  0.28622
2   5   11  0.61241

and this group_by operation:

df.group_by(["a", "b"], maintain_order=True).agg(pl.col('d').min())
shape: (6, 3)
a   b   d
i64 i64 f64
0   0   0.184183
0   1   0.101151
1   2   0.039985
1   3   0.245532
2   4   0.075254
2   5   0.28622

I need the values in column c of the original dataframe that correspond to the aggregated values in column d.

In Pandas I can get that result with:

import pandas as pd
df2 = pd.DataFrame(dct)
df2.iloc[df2.groupby(["a", "b"]).agg({"d": "idxmin"}).d]
    a   b   c   d
0   0   0   0   0.184183
3   0   1   3   0.101151
4   1   2   4   0.039985
7   1   3   7   0.245532
9   2   4   9   0.075254
10  2   5   10  0.286220
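For reference, groupby(...)["d"].idxmin() returns index labels, not row positions. With the default RangeIndex used here the two coincide, so positional and label-based selection pick the same rows. The sketch below uses .loc, which selects by label and therefore stays correct for any index; the shifted index is purely illustrative and not part of the original data:

```python
import pandas as pd

dct = {
    "a": [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2],
    "b": [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5],
    "c": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
    "d": [0.18418279, 0.67394382, 0.80951643, 0.10115085, 0.03998497, 0.17771175,
          0.28428486, 0.24553192, 0.31388881, 0.07525366, 0.28622033, 0.61240989],
}
df2 = pd.DataFrame(dct)

# For each (a, b) group, idxmin gives the index *label* of the row with the smallest d
labels = df2.groupby(["a", "b"])["d"].idxmin()
# .loc selects rows by label
result = df2.loc[labels]

# Illustrative only: shift the index so labels no longer equal positions.
shifted = df2.set_axis(df2.index + 100)
shifted_labels = shifted.groupby(["a", "b"])["d"].idxmin()
shifted_result = shifted.loc[shifted_labels]  # still selects the same rows
print(shifted_result)
```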

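I haven't found an equivalent of idxmin in Polars. There is arg_min, but inside a group_by it returns the position of the minimum within each group rather than a row index of the original dataframe, so it does not give the desired result, as this output shows: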

df.group_by(["a", "b"], maintain_order=True).agg(pl.col('d').arg_min())
shape: (6, 3)
a   b   d
i64 i64 u32
0   0   0
0   1   1
1   2   0
1   3   1
2   4   1
2   5   0

The real dataframe has several million rows, and the Pandas version takes minutes to compute. How can I achieve the same result efficiently in Polars?


Solution

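  • You can use Expr.get to take a value from a column by index. Inside agg, both expressions operate on the same group, so the group-relative position returned by arg_min on d selects the matching value of c.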

    df.group_by(["a", "b"], maintain_order=True).agg(
        c=pl.col("c").get(pl.col("d").arg_min()), d=pl.col("d").min(),
    )
    
    ┌─────┬─────┬─────┬──────────┐
    │ a   ┆ b   ┆ c   ┆ d        │
    │ --- ┆ --- ┆ --- ┆ ---      │
    │ i64 ┆ i64 ┆ i64 ┆ f64      │
    ╞═════╪═════╪═════╪══════════╡
    │ 0   ┆ 0   ┆ 0   ┆ 0.184183 │
    │ 0   ┆ 1   ┆ 3   ┆ 0.101151 │
    │ 1   ┆ 2   ┆ 4   ┆ 0.039985 │
    │ 1   ┆ 3   ┆ 7   ┆ 0.245532 │
    │ 2   ┆ 4   ┆ 9   ┆ 0.075254 │
    │ 2   ┆ 5   ┆ 10  ┆ 0.28622  │
    └─────┴─────┴─────┴──────────┘