Suppose I have a data set with a group column, subgroup column, and value column:
import polars as pl

df = pl.DataFrame(dict(
    grp = ['A', 'A', 'A', 'B', 'B', 'B'],
    subgroup = ['x', 'x', 'y', 'x', 'x', 'y'],
    value = [1, 2, 3, 4, 5, 6]
))
┌─────┬──────────┬───────┐
│ grp ┆ subgroup ┆ value │
│ --- ┆ --- ┆ --- │
│ str ┆ str ┆ i64 │
╞═════╪══════════╪═══════╡
│ A ┆ x ┆ 1 │
│ A ┆ x ┆ 2 │
│ A ┆ y ┆ 3 │
│ B ┆ x ┆ 4 │
│ B ┆ x ┆ 5 │
│ B ┆ y ┆ 6 │
└─────┴──────────┴───────┘
I'd like to calculate the mean value for each group, and also the mean of only those values where subgroup is "x". In polars, the most succinct way I've found to do this is:
(
    df
    .group_by('grp')
    .agg(
        pl.col('value').mean().alias('mean_all'),
        pl.when(pl.col('subgroup') == 'x').then(pl.col('value')).mean().alias('mean_x')
    )
)
┌─────┬──────────┬────────┐
│ grp ┆ mean_all ┆ mean_x │
│ --- ┆ --- ┆ --- │
│ str ┆ f64 ┆ f64 │
╞═════╪══════════╪════════╡
│ A ┆ 2.0 ┆ 1.5 │
│ B ┆ 5.0 ┆ 4.5 │
└─────┴──────────┴────────┘
This works fine, but the call to when...then inside agg produces a warning about this not being a valid aggregation (even though the chain is finished with mean). Is there a more idiomatic or elegant way to perform this operation from inside a polars chain?
You can filter within the aggregation instead: apply filter to the value column inside agg, so that only the rows where subgroup is 'x' go into the mean.
(
    df
    .group_by('grp', maintain_order=True)
    .agg(
        pl.col('value').mean().alias('mean_all'),
        pl.col("value").filter(pl.col("subgroup") == "x").mean().alias("mean_x")
    )
)
This produces:
shape: (2, 3)
┌─────┬──────────┬────────┐
│ grp ┆ mean_all ┆ mean_x │
│ --- ┆ --- ┆ --- │
│ str ┆ f64 ┆ f64 │
╞═════╪══════════╪════════╡
│ A ┆ 2.0 ┆ 1.5 │
│ B ┆ 5.0 ┆ 4.5 │
└─────┴──────────┴────────┘