I'm looking for a Polars equivalent of Pandas idxmin in a group_by agg operation
With Polars and this example dataframe:
import polars as pl
dct = {
    "a": [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2],
    "b": [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5],
    "c": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
    "d": [
        0.18418279, 0.67394382, 0.80951643, 0.10115085,
        0.03998497, 0.17771175, 0.28428486, 0.24553192,
        0.31388881, 0.07525366, 0.28622033, 0.61240989,
    ],
}
df = pl.DataFrame(dct)
a b c d
i64 i64 i64 f64
0 0 0 0.184183
0 0 1 0.673944
0 1 2 0.809516
0 1 3 0.101151
1 2 4 0.039985
1 2 5 0.177712
1 3 6 0.284285
1 3 7 0.245532
2 4 8 0.313889
2 4 9 0.075254
2 5 10 0.28622
2 5 11 0.61241
and group_by operation:
df.group_by(["a", "b"], maintain_order=True).agg(pl.col('d').min())
shape: (6, 3)
a b d
i64 i64 f64
0 0 0.184183
0 1 0.101151
1 2 0.039985
1 3 0.245532
2 4 0.075254
2 5 0.28622
I need the values in column c of the original dataframe that correspond to the aggregated values in column d.
In Pandas I can get that result with:
import pandas as pd
df2 = pd.DataFrame(dct)
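# idxmin returns index labels, so .loc is the matching accessor
# (with the default RangeIndex, .iloc happens to give the same rows)
df2.loc[df2.groupby(["a", "b"]).agg({"d": "idxmin"}).d]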
    a  b   c         d
0   0  0   0  0.184183
3   0  1   3  0.101151
4   1  2   4  0.039985
7   1  3   7  0.245532
9   2  4   9  0.075254
10  2  5  10  0.286220
I've not found an equivalent of idxmin in Polars. There is arg_min, but inside agg it returns the position of the minimum within each group rather than the row index in the original dataframe, as this output shows:
df.group_by(["a", "b"], maintain_order=True).agg(pl.col('d').arg_min())
shape: (6, 3)
a b d
i64 i64 u32
0 0 0
0 1 1
1 2 0
1 3 1
2 4 1
2 5 0
The actual dataframe has several million rows, and the Pandas version takes minutes to compute. How can I achieve the same result efficiently in Polars?
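You can use Expr.get to take a value from a column by index. Inside agg, arg_min gives the position of the minimum of d within each group, and get uses that position to pick the matching value of c from the same group.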
df.group_by(["a", "b"], maintain_order=True).agg(
    c=pl.col("c").get(pl.col("d").arg_min()),
    d=pl.col("d").min(),
)
┌─────┬─────┬─────┬──────────┐
│ a ┆ b ┆ c ┆ d │
│ --- ┆ --- ┆ --- ┆ --- │
│ i64 ┆ i64 ┆ i64 ┆ f64 │
╞═════╪═════╪═════╪══════════╡
│ 0 ┆ 0 ┆ 0 ┆ 0.184183 │
│ 0 ┆ 1 ┆ 3 ┆ 0.101151 │
│ 1 ┆ 2 ┆ 4 ┆ 0.039985 │
│ 1 ┆ 3 ┆ 7 ┆ 0.245532 │
│ 2 ┆ 4 ┆ 9 ┆ 0.075254 │
│ 2 ┆ 5 ┆ 10 ┆ 0.28622 │
└─────┴─────┴─────┴──────────┘