
How to use polars.lit in group_by aggregation context


Basically, I want to calculate the sum of powers of a constant, with the exponents taken from a column.

As an illustration:

import polars as pl

c = 2

df = pl.DataFrame({"a": [1, 2, 3]})

df.select(pl.lit(c).pow(pl.col("a")).sum())
shape: (1, 1)
┌─────────┐
│ literal │
│ ---     │
│ i32     │
╞═════════╡
│ 14      │
└─────────┘
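For reference, the quantity being computed is a sum of powers of the constant c over the values a_i in column "a". With c = 2 and a = [1, 2, 3] it works out to the 14 shown above:

```latex
\sum_{i} c^{a_i} = 2^1 + 2^2 + 2^3 = 2 + 4 + 8 = 14
```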

However, when I try to do the same thing in a group_by/agg context, I get an error:

import polars as pl

c = 2

df = pl.DataFrame({"a": [1, 2, 3, 1, 2, 3], "b": [1, 1, 1, 2, 2, 2]})

df.group_by("b").agg(pl.lit(c).pow(pl.col("a")).sum())
# Error originated in expression: '2i32.pow([col("a")])'

One temporary workaround is to add the constant as a column on df beforehand:

df = df.with_columns(pl.lit(c).alias("c"))

But this is not clean, because the original df may already have a column named "c" (or whatever alias I give it), which would cause a column-name collision.

I am sure there is a better and cleaner way to do this. But how?


Solution

  • Update: The old behaviour was a bug. Literals are now broadcast as expected.

    df.group_by("b").agg(pl.lit(c).pow(pl.col("a")).sum())
    
    shape: (2, 2)
    ┌─────┬─────────┐
    │ b   ┆ literal │
    │ --- ┆ ---     │
    │ i64 ┆ i32     │
    ╞═════╪═════════╡
    │ 1   ┆ 14      │
    │ 2   ┆ 14      │
    └─────┴─────────┘
    

    Original answer (for versions with the old behaviour): it looks like you may need .repeat_by to get the lengths to match.

    df.group_by("b").agg(
        pl.lit(c).repeat_by("a").pow(pl.col("a")).sum()
    )
    
    shape: (2, 2)
    ┌─────┬─────────┐
    │ b   ┆ literal │
    │ --- ┆ ---     │
    │ i64 ┆ f64     │
    ╞═════╪═════════╡
    │ 1   ┆ 14.0    │
    │ 2   ┆ 14.0    │
    └─────┴─────────┘