python python-polars

Polars: aggregating a subset of rows


Suppose I have a data set with a group column, subgroup column, and value column:

import polars as pl

df = pl.DataFrame(dict(
    grp = ['A', 'A', 'A', 'B', 'B', 'B'],
    subgroup = ['x', 'x', 'y', 'x', 'x', 'y'],
    value = [1, 2, 3, 4, 5, 6]
))

┌─────┬──────────┬───────┐
│ grp ┆ subgroup ┆ value │
│ --- ┆ ---      ┆ ---   │
│ str ┆ str      ┆ i64   │
╞═════╪══════════╪═══════╡
│ A   ┆ x        ┆ 1     │
│ A   ┆ x        ┆ 2     │
│ A   ┆ y        ┆ 3     │
│ B   ┆ x        ┆ 4     │
│ B   ┆ x        ┆ 5     │
│ B   ┆ y        ┆ 6     │
└─────┴──────────┴───────┘

I'd like to calculate the mean value for each group, and also the mean of only those values where subgroup is "x". In Polars, the most succinct way I've found to do this is:

(
    df
    .groupby(pl.col('grp'))
    .agg(
        pl.col('value').mean().alias('mean_all'),
        pl.when(pl.col('subgroup') == 'x').then(pl.col('value')).mean().alias('mean_x')
    )
)

┌─────┬──────────┬────────┐
│ grp ┆ mean_all ┆ mean_x │
│ --- ┆ ---      ┆ ---    │
│ str ┆ f64      ┆ f64    │
╞═════╪══════════╪════════╡
│ A   ┆ 2.0      ┆ 1.5    │
│ B   ┆ 5.0      ┆ 4.5    │
└─────┴──────────┴────────┘

This works, but calling when(...).then(...) inside agg produces a warning that it is not a valid aggregation, even though the chain ends with mean. Is there a more idiomatic or elegant way to perform this operation from inside a Polars chain?


Solution

  • You can filter the expression inside agg, so that only the rows where subgroup is 'x' go into the second mean.

    (
        df
        .groupby(pl.col('grp'), maintain_order=True)
        .agg(
            pl.col('value').mean().alias('mean_all'),
            pl.col("value").filter(pl.col("subgroup") == "x").mean().alias("mean_x")
        )
    )
    

    This produces:

    shape: (2, 3)
    ┌─────┬──────────┬────────┐
    │ grp ┆ mean_all ┆ mean_x │
    │ --- ┆ ---      ┆ ---    │
    │ str ┆ f64      ┆ f64    │
    ╞═════╪══════════╪════════╡
    │ A   ┆ 2.0      ┆ 1.5    │
    │ B   ┆ 5.0      ┆ 4.5    │
    └─────┴──────────┴────────┘