
Rolling aggregation in Polars that also keeps the original column, without a join or .over


Using Polars .rolling and .agg, how do I keep the original column in the result, without joining the aggregation back to the original frame and without using .over?

Example:

import polars as pl

dates = [
    "2020-01-01 13:45:48",
    "2020-01-01 16:42:13",
    "2020-01-01 16:45:09",
    "2020-01-02 18:12:48",
    "2020-01-03 19:45:32",
    "2020-01-08 23:16:43",
]
df = pl.DataFrame({"dt": dates, "a": [3, 7, 5, 9, 2, 1]}).with_columns(
    pl.col("dt").str.strptime(pl.Datetime).set_sorted()
)

This gives a small Polars DataFrame:

   dt                   a
0  2020-01-01 13:45:48  3
1  2020-01-01 16:42:13  7
2  2020-01-01 16:45:09  5
3  2020-01-02 18:12:48  9
4  2020-01-03 19:45:32  2
5  2020-01-08 23:16:43  1

When I apply rolling aggregations, I get the new columns back, but not the original a column:

out = df.rolling(index_column="dt", period="2d").agg(
    [
        pl.sum("a").alias("sum_a"),
        pl.min("a").alias("min_a"),
        pl.max("a").alias("max_a"),
        pl.col('a')
    ]
)

which gives:

   dt                   sum_a  min_a  max_a  a
0  2020-01-01 13:45:48  3      3      3      [3]
1  2020-01-01 16:42:13  10     3      7      [3 7]
2  2020-01-01 16:45:09  15     3      7      [3 7 5]
3  2020-01-02 18:12:48  24     3      9      [3 7 5 9]
4  2020-01-03 19:45:32  11     2      9      [9 2]
5  2020-01-08 23:16:43  1      1      1      [1]

Instead, a comes back as a list of the values in each window. How can I get the original a column back? I don't want to join, and I don't want to use .over, because I need the group_by argument of .rolling later on and .over does not work with .rolling.

Edit: I would also prefer not to use the following:

out = df.rolling(index_column="dt", period="2d").agg(
    [
        pl.sum("a").alias("sum_a"),
        pl.min("a").alias("min_a"),
        pl.max("a").alias("max_a"),
        pl.col('a').last().alias('a')
    ]
)

Edit 2: why Expr.rolling() is not feasible and why I need group_by:

Consider a more elaborate example:

dates = [
    "2020-01-01 13:45:48",
    "2020-01-01 16:42:13",
    "2020-01-01 16:45:09",
    "2020-01-02 18:12:48",
    "2020-01-03 19:45:32",
    "2020-01-08 23:16:43",
]
df_a = pl.DataFrame(
    {"dt": dates, "a": [3, 7, 5, 9, 2, 1], "cat": ["one"] * 6}
).with_columns(pl.col("dt").str.strptime(pl.Datetime).set_sorted())
df_b = pl.DataFrame(
    {"dt": dates, "a": [3, 7, 5, 9, 2, 1], "cat": ["two"] * 6}
).with_columns(pl.col("dt").str.strptime(pl.Datetime).set_sorted())

df = pl.concat([df_a, df_b])

which gives:

    dt                   a  cat
0   2020-01-01 13:45:48  3  one
1   2020-01-01 16:42:13  7  one
2   2020-01-01 16:45:09  5  one
3   2020-01-02 18:12:48  9  one
4   2020-01-03 19:45:32  2  one
5   2020-01-08 23:16:43  1  one
6   2020-01-01 13:45:48  3  two
7   2020-01-01 16:42:13  7  two
8   2020-01-01 16:45:09  5  two
9   2020-01-02 18:12:48  9  two
10  2020-01-03 19:45:32  2  two
11  2020-01-08 23:16:43  1  two

and the code:

out = df.rolling(index_column="dt", period="2d", group_by="cat").agg(
    [
        pl.sum("a").alias("sum_a"),
        pl.min("a").alias("min_a"),
        pl.max("a").alias("max_a"),
        pl.col("a"),
    ]
)

which gives:

    cat  dt                   sum_a  min_a  max_a  a
0   one  2020-01-01 13:45:48  3      3      3      [3]
1   one  2020-01-01 16:42:13  10     3      7      [3 7]
2   one  2020-01-01 16:45:09  15     3      7      [3 7 5]
3   one  2020-01-02 18:12:48  24     3      9      [3 7 5 9]
4   one  2020-01-03 19:45:32  11     2      9      [9 2]
5   one  2020-01-08 23:16:43  1      1      1      [1]
6   two  2020-01-01 13:45:48  3      3      3      [3]
7   two  2020-01-01 16:42:13  10     3      7      [3 7]
8   two  2020-01-01 16:45:09  15     3      7      [3 7 5]
9   two  2020-01-02 18:12:48  24     3      9      [3 7 5 9]
10  two  2020-01-03 19:45:32  11     2      9      [9 2]
11  two  2020-01-08 23:16:43  1      1      1      [1]

Combining Expr.rolling() with .over() does not work:

df.sort('dt').with_columns(sum=pl.sum("a").rolling(index_column="dt", period="2d").over(["cat"]))

It raises:

InvalidOperationError: rolling expression not allowed in aggregation

Solution

  • Polars has dedicated rolling_*_by expressions (for example rolling_sum_by, rolling_min_by and rolling_max_by) which, unlike Expr.rolling(), can be used with .over():

    df.with_columns(
        pl.col("a").rolling_sum_by("dt", "2d").over("cat").name.prefix("sum_"),
        pl.col("a").rolling_min_by("dt", "2d").over("cat").name.prefix("min_"),
        pl.col("a").rolling_max_by("dt", "2d").over("cat").name.prefix("max_")
    )
    
    shape: (12, 6)
    ┌─────────────────────┬─────┬─────┬───────┬───────┬───────┐
    │ dt                  ┆ a   ┆ cat ┆ sum_a ┆ min_a ┆ max_a │
    │ ---                 ┆ --- ┆ --- ┆ ---   ┆ ---   ┆ ---   │
    │ datetime[μs]        ┆ i64 ┆ str ┆ i64   ┆ i64   ┆ i64   │
    ╞═════════════════════╪═════╪═════╪═══════╪═══════╪═══════╡
    │ 2020-01-01 13:45:48 ┆ 3   ┆ one ┆ 3     ┆ 3     ┆ 3     │
    │ 2020-01-01 16:42:13 ┆ 7   ┆ one ┆ 10    ┆ 3     ┆ 7     │
    │ 2020-01-01 16:45:09 ┆ 5   ┆ one ┆ 15    ┆ 3     ┆ 7     │
    │ 2020-01-02 18:12:48 ┆ 9   ┆ one ┆ 24    ┆ 3     ┆ 9     │
    │ 2020-01-03 19:45:32 ┆ 2   ┆ one ┆ 11    ┆ 2     ┆ 9     │
    │ …                   ┆ …   ┆ …   ┆ …     ┆ …     ┆ …     │
    │ 2020-01-01 16:42:13 ┆ 7   ┆ two ┆ 10    ┆ 3     ┆ 7     │
    │ 2020-01-01 16:45:09 ┆ 5   ┆ two ┆ 15    ┆ 3     ┆ 7     │
    │ 2020-01-02 18:12:48 ┆ 9   ┆ two ┆ 24    ┆ 3     ┆ 9     │
    │ 2020-01-03 19:45:32 ┆ 2   ┆ two ┆ 11    ┆ 2     ┆ 9     │
    │ 2020-01-08 23:16:43 ┆ 1   ┆ two ┆ 1     ┆ 1     ┆ 1     │
    └─────────────────────┴─────┴─────┴───────┴───────┴───────┘