Polars equivalent of df.sum(axis=1) for expressions (and other functions, maybe median?)


I'm trying to find an idiomatic Polars way of aggregating data per row. This isn't strictly about the .sum function; it applies to any aggregation where an axis argument makes sense.

Take a look at these pandas examples:

import pandas as pd

df = pd.DataFrame([[1, 2, 1], [4, 5, 3], [1, 8, 9]])

df[df.sum(axis=1) > 5]
#    0  1  2
# 1  4  5  3
# 2  1  8  9

df.assign(median=df.median(axis=1))
#    0  1  2  median
# 0  1  2  1     1.0
# 1  4  5  3     4.0
# 2  1  8  9     8.0

df.T.rolling(3).mean().T # EDIT: rolling(..., axis=1) has been deprecated
#     0   1         2
# 0 NaN NaN  1.333333
# 1 NaN NaN  4.000000
# 2 NaN NaN  6.000000

So the question is: how to deal with these situations using the Polars API?


Solution

  • As I understand it, the row-wise (horizontal) approaches in Polars currently rank as follows, from most to least efficient:

    1. dedicated .*_horizontal methods
    2. dedicated .list.* methods
    3. .list.eval()

    For your specific examples (assuming df is a Polars DataFrame holding the same data, with columns named "0", "1" and "2"):

    horizontal

    First, we can use pl.sum_horizontal():

    df.filter(pl.sum_horizontal(pl.all()) > 5)
    
    shape: (2, 3)
    ┌─────┬─────┬─────┐
    │ 0   ┆ 1   ┆ 2   │
    │ --- ┆ --- ┆ --- │
    │ i64 ┆ i64 ┆ i64 │
    ╞═════╪═════╪═════╡
    │ 4   ┆ 5   ┆ 3   │
    │ 1   ┆ 8   ┆ 9   │
    └─────┴─────┴─────┘
    

    list

    There is no median_horizontal yet, but there is a .list.median() method.

    pl.concat_list() is used to combine the columns into a single list column first.

    df.with_columns(median = pl.concat_list(pl.all()).list.median())
    
    shape: (3, 4)
    ┌─────┬─────┬─────┬────────┐
    │ 0   ┆ 1   ┆ 2   ┆ median │
    │ --- ┆ --- ┆ --- ┆ ---    │
    │ i64 ┆ i64 ┆ i64 ┆ f64    │
    ╞═════╪═════╪═════╪════════╡
    │ 1   ┆ 2   ┆ 1   ┆ 1.0    │
    │ 4   ┆ 5   ┆ 3   ┆ 4.0    │
    │ 1   ┆ 8   ┆ 9   ┆ 8.0    │
    └─────┴─────┴─────┴────────┘
    

    list.eval()

    There are no horizontal or list rolling methods, but .list.eval() can be used to access the full expression set for lists.

    df.with_columns(
        pl.concat_list(pl.all())
        .list.eval(pl.element().rolling_mean(3).last())  # keep only the last window of each row
        .list.last()  # unwrap the single-element list, e.g. [6.0] -> 6.0
        .alias("rolling_mean")
    )
    
    shape: (3, 4)
    ┌─────┬─────┬─────┬──────────────┐
    │ 0   ┆ 1   ┆ 2   ┆ rolling_mean │
    │ --- ┆ --- ┆ --- ┆ ---          │
    │ i64 ┆ i64 ┆ i64 ┆ f64          │
    ╞═════╪═════╪═════╪══════════════╡
    │ 1   ┆ 2   ┆ 1   ┆ 1.333333     │
    │ 4   ┆ 5   ┆ 3   ┆ 4.0          │
    │ 1   ┆ 8   ┆ 9   ┆ 6.0          │
    └─────┴─────┴─────┴──────────────┘