
Custom `polars` expression involving multiple columns and filtering


Consider the following function acting on a polars.DataFrame:

import polars as pl


def f(frame: pl.DataFrame) -> pl.DataFrame:
    """Select columns 'A', 'B' from `frame` and return only those rows for which column 'C' is greater than 7."""
    return frame.filter(pl.col("C") > 7.0).select("A", "B")


if __name__ == "__main__":
    frame = pl.from_dict({"A": [1, 2, 3], "B": [4, 5, 6], "C": [7, 8, 9]})
    result = frame.pipe(f)
    print(result)
>>> shape: (2, 2)
┌─────┬─────┐
│ A   ┆ B   │
│ --- ┆ --- │
│ i64 ┆ i64 │
╞═════╪═════╡
│ 2   ┆ 5   │
│ 3   ┆ 6   │
└─────┴─────┘

In order to abstract away from specific column names I can write this as

import polars as pl


def g(a: pl.Series, b: pl.Series, c: pl.Series) -> pl.DataFrame:
    frame = pl.concat([a.to_frame(), b.to_frame(), c.to_frame()], how="horizontal")
    # Use the series' own names so the function does not depend on "A", "B", "C".
    return frame.filter(pl.col(c.name) > 7.0).select(a.name, b.name)

if __name__ == "__main__":
    frame = pl.from_dict({"A": [1, 2, 3], "B": [4, 5, 6], "C": [7, 8, 9]})
    result = g(frame.get_column("A"), frame.get_column("B"), frame.get_column("C"))
    print(result)
>>> shape: (2, 2)
┌─────┬─────┐
│ A   ┆ B   │
│ --- ┆ --- │
│ i64 ┆ i64 │
╞═════╪═════╡
│ 2   ┆ 5   │
│ 3   ┆ 6   │
└─────┴─────┘

This seems cumbersome: it involves extracting each pl.Series from the pl.DataFrame, converting each one back to a pl.DataFrame, and joining them again with pl.concat.

Would it be possible to write this as a custom expression instead? I'd like to apply it as outlined below:

import polars as pl


def h(a: pl.Expr, b: pl.Expr, c: pl.Expr) -> pl.Expr:
    # How to represent f (or g) in terms of only `pl.Expr`?
    pass


if __name__ == "__main__":
    frame = pl.from_dict({"A": [1, 2, 3], "B": [4, 5, 6], "C": [7, 8, 9]})
    result = frame.select(h(pl.col("A"), pl.col("B"), pl.col("C")))
    print(result)
>>> shape: (2, 2)
┌─────┬─────┐
│ A   ┆ B   │
│ --- ┆ --- │
│ i64 ┆ i64 │
╞═════╪═════╡
│ 2   ┆ 5   │
│ 3   ┆ 6   │
└─────┴─────┘

I am not sure whether this is the intended usage of the pl.Expr type. It seems a nice way to avoid relying on specific column names, although it implicitly requires that the expressions evaluate to series of identical length and of a data type suitable for the comparison.


Solution

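  • Since you're passing the result to DataFrame.select(), you can use Expr.filter() to build the expressions. select() accepts several expressions, so the function can return one filtered expression per output column.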

    import polars as pl
    
    def my_func(a, b, c):
        # Apply the same boolean mask `c` to both expressions.
        return a.filter(c), b.filter(c)
    
    df = pl.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6], "C": [7, 8, 9]})
    
    df.select(my_func(pl.col("A"), pl.col("B"), pl.col("C") > 7))
    
    shape: (2, 2)
    ┌─────┬─────┐
    │ A   ┆ B   │
    │ --- ┆ --- │
    │ i64 ┆ i64 │
    ╞═════╪═════╡
    │ 2   ┆ 5   │
    │ 3   ┆ 6   │
    └─────┴─────┘