Given the following dataframe, I would like to group by "foo", sort each group on "bar", and then keep the whole first row of each group.
import polars as pl
df = pl.DataFrame(
    {
        "foo": [1, 1, 1, 2, 2, 2, 3],
        "bar": [5, 7, 6, 4, 2, 3, 1],
        "baz": [1, 2, 3, 4, 5, 6, 7],
    }
)
df_desired = pl.DataFrame({"foo": [1, 2, 3], "bar": [5, 2, 1], "baz": [1,5,7]})
>>> df_desired
shape: (3, 3)
┌─────┬─────┬─────┐
│ foo ┆ bar ┆ baz │
│ --- ┆ --- ┆ --- │
│ i64 ┆ i64 ┆ i64 │
╞═════╪═════╪═════╡
│ 1 ┆ 5 ┆ 1 │
│ 2 ┆ 2 ┆ 5 │
│ 3 ┆ 1 ┆ 7 │
└─────┴─────┴─────┘
I can do this by sorting the whole frame beforehand, but this is expensive compared to sorting within each group:
df_solution = df.sort("bar").group_by("foo", maintain_order=True).first().sort(by="foo")
assert df_desired.equals(df_solution)
I can sort "bar" inside the aggregation, as in this SO answer:
>>> df.group_by("foo").agg(pl.col("bar").sort().first()).sort(by="foo")
shape: (3, 2)
┌─────┬─────┐
│ foo ┆ bar │
│ --- ┆ --- │
│ i64 ┆ i64 │
╞═════╪═════╡
│ 1 ┆ 5 │
│ 2 ┆ 2 │
│ 3 ┆ 1 │
└─────┴─────┘
but then I only get that column. How do I also keep the "baz" value from the same row? Any additional expressions passed to .agg([]) are evaluated independently of pl.col("bar").sort().
You could get the index of the row with the minimum "bar" in each group with pl.col("bar").arg_min(), which can be passed to pl.all().get() to access the corresponding value in every column.
(df.group_by("foo", maintain_order=True)
   .agg(pl.all().get(pl.col("bar").arg_min()))
)
shape: (3, 3)
┌─────┬─────┬─────┐
│ foo ┆ bar ┆ baz │
│ --- ┆ --- ┆ --- │
│ i64 ┆ i64 ┆ i64 │
╞═════╪═════╪═════╡
│ 1 ┆ 5 ┆ 1 │
│ 2 ┆ 2 ┆ 5 │
│ 3 ┆ 1 ┆ 7 │
└─────┴─────┴─────┘
Or, closer to your example, you could use .sort_by() inside .agg():
(df.group_by("foo", maintain_order=True)
.agg(pl.all().sort_by("bar").first())
)
shape: (3, 3)
┌─────┬─────┬─────┐
│ foo ┆ bar ┆ baz │
│ --- ┆ --- ┆ --- │
│ i64 ┆ i64 ┆ i64 │
╞═════╪═════╪═════╡
│ 1 ┆ 5 ┆ 1 │
│ 2 ┆ 2 ┆ 5 │
│ 3 ┆ 1 ┆ 7 │
└─────┴─────┴─────┘