
Polars aggregation warning using when->then


Update: This has been resolved and no longer issues a warning.


Consider the following:

import polars as pl

df = pl.from_repr("""
┌─────┬─────┐
│ a   ┆ b   │
│ --- ┆ --- │
│ i64 ┆ i64 │
╞═════╪═════╡
│ 1   ┆ 1   │
│ 2   ┆ 1   │
└─────┴─────┘
""")
In [10]: df.group_by("a").agg(pl.when(pl.col("b") == 1).then(pl.col("b")))
The predicate '[(col("b")) == (1)]' in 'when->then->otherwise' is not a valid aggregation and might produce a different number of rows than the groupby operation would. This behavior is experimental and may be subject to change
Out[10]: 
shape: (2, 2)
┌─────┬───────────┐
│ a   ┆ b         │
│ --- ┆ ---       │
│ i64 ┆ list[i64] │
╞═════╪═══════════╡
│ 2   ┆ [1]       │
│ 1   ┆ [1]       │
└─────┴───────────┘

Is there something to worry about? The when->then has to produce a value even if it's null.


Solution

  • The problem in the aggregation arises from comparing the group "b" with a literal 1 and then replacing values with the group "b". We haven't formalized the vectorization rules for that expression yet, hence the warning.

    It is more explicit (and easier for us to understand) to apply your ternary expression in a with_columns and then do an aggregation:

    import polars as pl

    df = pl.from_repr("""shape: (2, 2)
    ┌─────┬─────┐
    │ a   ┆ b   │
    │ --- ┆ --- │
    │ i64 ┆ i64 │
    ╞═════╪═════╡
    │ 1   ┆ 1   │
    │ 2   ┆ 1   │
    └─────┴─────┘
    """)
    
    (df.with_columns(
        pl.when(pl.col("b") == 1).then(pl.col("b"))
    ).group_by("a").all())
    
    shape: (2, 2)
    ┌─────┬───────────┐
    │ a   ┆ b         │
    │ --- ┆ ---       │
    │ i64 ┆ list[i64] │
    ╞═════╪═══════════╡
    │ 1   ┆ [1]       │
    │ 2   ┆ [1]       │
    └─────┴───────────┘