Search code examples
python, python-polars

How to pass aggregation functions as function argument in Polars?


How can we pass aggregation functions as an argument to a custom aggregation function in Polars? It should be possible to pass either a single function that is applied to all columns, or a dictionary mapping column names to functions when different columns need different aggregations.

import polars as pl

# Toy input: a string grouping key ("category") and an integer column ("value").
df = pl.DataFrame(
    {
        "category": ["A", "A", "B", "B", "B"],
        "value": [1, 2, 3, 4, 5],
    }
)

def agg_with_sum(df: pl.DataFrame | pl.LazyFrame) -> pl.DataFrame | pl.LazyFrame:
    """Group ``df`` by ``"category"`` and sum the columns selected by ``pl.col("*")``."""
    per_category = df.group_by("category")
    summed_columns = pl.col("*").sum()
    return per_category.agg(summed_columns)

# Custom function to perform aggregation.
# NOTE: this is the non-working attempt the question is about. Both return
# paths raise AttributeError (see the calls further down): `.aggexpr()` and
# `.agg_expr()` are resolved as literal attribute names on the `pl.col(...)`
# expression, so the value supplied by the caller is never actually invoked.
def agg_with_expr(df: pl.DataFrame | pl.LazyFrame,
                  agg_expr: pl.Expr | dict[str, pl.Expr]) -> pl.DataFrame | pl.LazyFrame:
    """Intended to group ``df`` by ``"category"`` and apply a caller-chosen aggregation.

    ``agg_expr`` is meant to be either one aggregation for every column or a
    dict keyed by column name with one aggregation per column.
    NOTE(review): the callers below pass functions such as ``pl.sum``, not
    ``pl.Expr`` instances, so the annotation looks inaccurate -- confirm.
    """
    if isinstance(agg_expr, dict):
        # Dict form: one aggregation per listed column. `.aggexpr()` is looked up
        # on the Expr itself rather than calling the dict value, hence the error.
        return df.group_by("category").agg([pl.col(col).aggexpr() for col, aggexpr in agg_expr.items()])
    # Single form: same problem -- `.agg_expr()` is an attribute lookup on Expr,
    # not a call of the `agg_expr` parameter.
    return df.group_by("category").agg(pl.col("*").agg_expr())

# Baseline: the hard-coded sum aggregation works and prints the expected result
print(agg_with_sum(df))
# ┌──────────┬───────┐
# │ category ┆ value │
# │ ---      ┆ ---   │
# │ str      ┆ i64   │
# ╞══════════╪═══════╡
# │ A        ┆ 3     │
# │ B        ┆ 12    │
# └──────────┴───────┘

# Trying to pass a single aggregation function (`pl.sum`) to apply to all columns
print(agg_with_expr(df, pl.sum))
# AttributeError: 'Expr' object has no attribute 'agg_expr'

# Trying to pass a dictionary with a per-column aggregation function
print(agg_with_expr(df, {'value': pl.sum}))
# AttributeError: 'Expr' object has no attribute 'aggexpr'

Solution

  • You can pass the aggregation as an anonymous function that takes an expression as its parameter (I simplified your example just to illustrate the point):

    def agg_with_expr(df, agg_expr):
        """Group ``df`` by ``"category"`` and aggregate with caller-supplied functions.

        Args:
            df: A ``pl.DataFrame`` or ``pl.LazyFrame`` with a ``"category"`` column.
            agg_expr: Either a single callable that takes a ``pl.Expr`` and
                returns the aggregated ``pl.Expr`` (for example
                ``lambda x: x.sum()`` or ``pl.Expr.sum``), applied to all
                columns; or a dict mapping a column name to such a callable
                (for example ``{"value": pl.Expr.sum}``) for per-column
                aggregations.

        Returns:
            The grouped and aggregated frame, of the same kind as ``df``.
        """
        if isinstance(agg_expr, dict):
            # Per-column form: call each function on the expression for its column.
            per_column = [fn(pl.col(name)) for name, fn in agg_expr.items()]
            return df.group_by("category").agg(per_column)
        # Single-function form: apply the function to the wildcard expression.
        return df.group_by("category").agg(agg_expr(pl.col("*")))
    
    agg_with_expr(df, lambda x: x.sum())
    
    shape: (2, 2)
    ┌──────────┬───────┐
    │ category ┆ value │
    │ ---      ┆ ---   │
    │ str      ┆ i64   │
    ╞══════════╪═══════╡
    │ B        ┆ 12    │
    │ A        ┆ 3     │
    └──────────┴───────┘
    

    Update: as @orlp mentioned in the comments, in this particular case you can do it without an anonymous function by passing the method pl.Expr.sum itself (uncalled), which is much neater.

    agg_with_expr(df, pl.Expr.sum)
    
    shape: (2, 2)
    ┌──────────┬───────┐
    │ category ┆ value │
    │ ---      ┆ ---   │
    │ str      ┆ i64   │
    ╞══════════╪═══════╡
    │ A        ┆ 3     │
    │ B        ┆ 12    │
    └──────────┴───────┘