python-polars

polars group_by returning a value when the filter does not match


Let's say I have this Polars DataFrame:

import polars as pl

df = pl.DataFrame(
    {
        "group_col": ["g1", "g1", "g2"],
        "b": [1, 2, 3],
        "c": [4, 5, 6],
    }
)

Output:

shape: (3, 3)
┌───────────┬─────┬─────┐
│ group_col ┆ b   ┆ c   │
│ ---       ┆ --- ┆ --- │
│ str       ┆ i64 ┆ i64 │
╞═══════════╪═════╪═════╡
│ g1        ┆ 1   ┆ 4   │
│ g1        ┆ 2   ┆ 5   │
│ g2        ┆ 3   ┆ 6   │
└───────────┴─────┴─────┘

I need to do a group_by like this:

df.group_by("group_col").agg(
    [
        pl.col("c").filter((pl.col("b") >= 1) & (pl.col('b').max() == pl.col('b'))).max().alias("gte"),
    ]
)

Output:

shape: (2, 2)
┌───────────┬─────┐
│ group_col ┆ gte │
│ ---       ┆ --- │
│ str       ┆ i64 │
╞═══════════╪═════╡
│ g1        ┆ 5   │
│ g2        ┆ 6   │
└───────────┴─────┘

In this case, why is the value in column gte for group g1 not null, given the & (pl.col('b').max() == pl.col('b')) filter?


Solution

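  • Inside group_by(...).agg(...), every expression only sees the rows of the current group. That includes pl.col('b').max(): it is the maximum of b within each group, not over the whole DataFrame.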

    Try this:

    (
        df
        .group_by('group_col')
        .agg(
            pl.all().max().name.suffix("max"),
            pl.all().min().name.suffix("min"),
        )
    )
    shape: (2, 5)
    ┌───────────┬──────┬──────┬──────┬──────┐
    │ group_col ┆ bmax ┆ cmax ┆ bmin ┆ cmin │
    │ ---       ┆ ---  ┆ ---  ┆ ---  ┆ ---  │
    │ str       ┆ i64  ┆ i64  ┆ i64  ┆ i64  │
    ╞═══════════╪══════╪══════╪══════╪══════╡
    │ g1        ┆ 2    ┆ 5    ┆ 1    ┆ 4    │
    │ g2        ┆ 3    ┆ 6    ┆ 3    ┆ 6    │
    └───────────┴──────┴──────┴──────┴──────┘
    

    As you can see, for group g1 cmax is 5, because that is the maximum within that group. Likewise, within g1 the maximum of b is 2, and the row where b == 2 has c == 5, so your filter keeps that row and gte for g1 comes out as 5.

    It seems like what you want is to compare b against its maximum over the whole DataFrame, which means applying that filter before the group_by:

    (
        df
        .filter(pl.col('b') == pl.col('b').max())
        .group_by("group_col")
        .agg(
            pl.col("c").filter(pl.col("b") >= 1).max().alias("gte"),
        )
    )
    shape: (1, 2)
    ┌───────────┬─────┐
    │ group_col ┆ gte │
    │ ---       ┆ --- │
    │ str       ┆ i64 │
    ╞═══════════╪═════╡
    │ g2        ┆ 6   │
    └───────────┴─────┘
    

    but you want g1 to be returned with a null value. One way to achieve that is to build a frame with one row per group (your default rows) and left join the filtered, aggregated result onto it:

    (
        df.select('group_col').unique()
        .join(
            df
            .filter((pl.col('b').max() == pl.col('b')) & (pl.col("b") >= 1))
            .group_by("group_col", maintain_order=True)
            .agg(
                pl.col("c").max().alias("gte"),
            ),
            on='group_col',
            how='left',
        )
        .sort('group_col')
    )
    shape: (2, 2)
    ┌───────────┬──────┐
    │ group_col ┆ gte  │
    │ ---       ┆ ---  │
    │ str       ┆ i64  │
    ╞═══════════╪══════╡
    │ g1        ┆ null │
    │ g2        ┆ 6    │
    └───────────┴──────┘