Tags: python, python-polars

Difference between a function returning a polars expression and a Python UDF with respect to input parameters


I am trying to understand the performance implications of using polars.struct compared to polars.col when performing a weighted sum.

Consider the following minimal example


import time

import numpy as np
import polars as pl


def weighted_sum(expr: pl.Expr) -> pl.Expr:
    return (expr.struct.field("value") * expr.struct.field("weight")).sum() / expr.struct.field("weight").sum()

if __name__ == "__main__":
    size = 10_000_000
    frame = pl.from_dict({"weight": np.random.rand(size), "value": np.random.randn(size)}).lazy()

    t0 = time.time()
    frame.select(pl.struct("weight", "value").pipe(weighted_sum)).collect()
    t = time.time() - t0
    print(f"Time for struct based weighted sum {t:.3f}s.")

    t0 = time.time()
    frame.select((pl.col("value") * pl.col("weight")).sum() / pl.col("weight").sum()).collect()
    t = time.time() - t0
    print(f"Time for col based weighted sum {t:.3f}s.")

The latter (the usual column-based expression) is about 3 times faster than the former; both are really fast given that we're aggregating 10M rows here. The relative performance is similar for 100M rows, and for 1G rows the latter is about 6 times faster than the struct-based version.

I'd like to understand what's going on.

Also, what would be the canonical way to define a custom expression/aggregation for a weighted sum that indicates that two arguments are needed, e.g. def weighted_sum(values, weights)? For this definition I am not sure about the correct type for values and weights (would that be polars.Expr or polars.Series?) and how I should invoke it.

I am primarily interested because I want to increase readability of code that computes a weighted sum in groupby aggregations, e.g.

frame.groupby(...).agg((pl.col("value") * pl.col("weight")).sum() / pl.col("weight").sum())

which could possibly be shortened to something like

frame.groupby(...).agg(pl.col("value", "weight").pipe(weighted_sum))

With the struct based aggregation I can already do

frame.groupby(...).agg(pl.struct("value", "weight").pipe(weighted_sum))

but there seems to be a performance overhead associated with it that I would like to understand, and possibly avoid. Ideally, without having to sacrifice readability by defining complex expressions involving multiple columns as custom expressions or functions.


Solution

  • I think there may be some confusion between a Python UDF and a function which returns an expression. In the former case, if the Python UDF needs access to multiple values, then the only way is to roll them up in a struct. In the latter case, that restriction doesn't apply, because the function is just shorthand for typing the expression out.

    As such you can do

    def weighted_sum(value: str, weight: str) -> pl.Expr:
        return (pl.col(value) * pl.col(weight)).sum() / pl.col(weight).sum()
    

    You could make it more complex/flexible as well, for instance:

    def weighted_sum(value: str | pl.Expr, weight: str | pl.Expr) -> pl.Expr:
        if isinstance(value, str):
            value = pl.col(value)
        if isinstance(weight, str):
            weight = pl.col(weight)
        return (value * weight).sum() / weight.sum()
    

    which would then be used like:

    frame.groupby(...).agg(weighted_sum("value", "weight"))
    

    or

    frame.groupby(...).agg(weighted_sum(pl.col("value"), pl.col("weight")))
    

    There's also some confusion on what it means to have

    frame.groupby(...).agg(pl.col("value", "weight").pipe(weighted_sum))
    

    When you put multiple column names in pl.col, then it is a shortcut for doing (in this case)

    frame.groupby(...).agg(pl.col("value").pipe(weighted_sum),
                           pl.col("weight").pipe(weighted_sum))
    

    It doesn't have the means to know that your function wants two input parameters.