
How to do if and else in Polars group_by context


Update: The vectorization rules have since been formalized, and the query below now runs as expected without a warning.


For a DataFrame, the goal is to compute the mean of column a, grouped by column b, provided the first value of a in each group is not null; if it is null, the result for that group should just be null.
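To pin down the intended semantics, here is a plain-Python reference sketch (no Polars involved; the function name first_guarded_mean is hypothetical). It groups the a values by b, returns None for any group whose first a is null, and otherwise averages the non-null values, mirroring Polars' mean, which skips nulls.

```python
def first_guarded_mean(rows):
    """Group (a, b) pairs by b. Return {b: mean of a}, or None for a group
    whose first a is None. Hypothetical helper, for illustration only."""
    groups = {}
    for a, b in rows:
        # dict preserves insertion order, so values[0] is the group's first row
        groups.setdefault(b, []).append(a)

    result = {}
    for b, values in groups.items():
        if values[0] is None:
            result[b] = None
        else:
            # Skip nulls, as Polars' mean does; non_null is never empty here
            non_null = [v for v in values if v is not None]
            result[b] = sum(non_null) / len(non_null)
    return result


rows = list(zip([None, 1, 2, 3, 4], [1, 1, 2, 2, 2]))
result = first_guarded_mean(rows)
print(result)  # {1: None, 2: 3.0}
```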

The sample DataFrame:

import polars as pl

df = pl.DataFrame({"a": [None, 1, 2, 3, 4], "b": [1, 1, 2, 2, 2]})

I tried something like:

df.group_by("b").agg(
    pl.when(pl.col("a").first().is_null()).then(None).otherwise(pl.mean("a"))
)

The results are as expected, but I get a warning saying that pl.when may not be guaranteed to behave correctly in a group_by context:

The predicate 'col("a").first().is_null()' in 'when->then->otherwise' is not a valid aggregation and might produce a different number of rows than the groupby operation would. This behavior is experimental and may be subject to change
shape: (2, 2)
┌─────┬─────────┐
│ b   ┆ literal │
│ --- ┆ ---     │
│ i64 ┆ f64     │
╞═════╪═════════╡
│ 1   ┆ null    │
│ 2   ┆ 3.0     │
└─────┴─────────┘

Why does this happen, and what is a better way to do if-else in a group_by?


Solution

  • You can use:

    • pl.col("a").is_null().first()

    instead of:

    • pl.col("a").first().is_null()

    If we look at both approaches:

    df.group_by("b", maintain_order=True).agg(
       pl.col("a"),
       pl.col("a").is_not_null().alias("yes"),
       pl.col("a").first().is_not_null().alias("no"),
    )
    
    shape: (2, 4)
    ┌─────┬───────────┬────────────────────┬───────┐
    │ b   ┆ a         ┆ yes                ┆ no    │
    │ --- ┆ ---       ┆ ---                ┆ ---   │
    │ i64 ┆ list[i64] ┆ list[bool]         ┆ bool  │
    ╞═════╪═══════════╪════════════════════╪═══════╡
    │ 1   ┆ [null, 1] ┆ [false, true]      ┆ false │
    │ 2   ┆ [2, 3, 4] ┆ [true, true, true] ┆ true  │
    └─────┴───────────┴────────────────────┴───────┘
    

    My understanding is that in the no case, only the first value of each group (null and 2) is passed to .is_not_null(); the rest of the inputs have been "silently discarded".

    Polars knows the groups of a have lengths 2 and 3, and it expects the "boolean masks" to have those same lengths.

    We can instead take the .first() value of yes, which gives the same end result:

    df.group_by("b", maintain_order=True).agg(
       pl.col("a"),
       pl.col("a").is_not_null().first().alias("yes"),
       pl.col("a").first().is_not_null().alias("no"),
    )
    
    shape: (2, 4)
    ┌─────┬───────────┬───────┬───────┐
    │ b   ┆ a         ┆ yes   ┆ no    │
    │ --- ┆ ---       ┆ ---   ┆ ---   │
    │ i64 ┆ list[i64] ┆ bool  ┆ bool  │
    ╞═════╪═══════════╪═══════╪═══════╡
    │ 1   ┆ [null, 1] ┆ false ┆ false │
    │ 2   ┆ [2, 3, 4] ┆ true  ┆ true  │
    └─────┴───────────┴───────┴───────┘
    

    But now all of the inputs are passed to .is_not_null() before .first() reduces the mask to a single value, so the length check passes.