Using polars .rolling and .agg, how do I get the original column back, without joining back to the original DataFrame and without using .over?
Example:
import polars as pl
dates = [
    "2020-01-01 13:45:48",
    "2020-01-01 16:42:13",
    "2020-01-01 16:45:09",
    "2020-01-02 18:12:48",
    "2020-01-03 19:45:32",
    "2020-01-08 23:16:43",
]
df = pl.DataFrame({"dt": dates, "a": [3, 7, 5, 9, 2, 1]}).with_columns(
pl.col("dt").str.strptime(pl.Datetime).set_sorted()
)
This gives me a small polars DataFrame:
 | dt | a |
---|---|---|
0 | 2020-01-01 13:45:48 | 3 |
1 | 2020-01-01 16:42:13 | 7 |
2 | 2020-01-01 16:45:09 | 5 |
3 | 2020-01-02 18:12:48 | 9 |
4 | 2020-01-03 19:45:32 | 2 |
5 | 2020-01-08 23:16:43 | 1 |
When I apply rolling aggregations, I get the new columns back, but not the original column:
out = df.rolling(index_column="dt", period="2d").agg(
    [
        pl.sum("a").alias("sum_a"),
        pl.min("a").alias("min_a"),
        pl.max("a").alias("max_a"),
        pl.col("a"),
    ]
)
which gives:
 | dt | sum_a | min_a | max_a | a |
---|---|---|---|---|---|
0 | 2020-01-01 13:45:48 | 3 | 3 | 3 | [3] |
1 | 2020-01-01 16:42:13 | 10 | 3 | 7 | [3 7] |
2 | 2020-01-01 16:45:09 | 15 | 3 | 7 | [3 7 5] |
3 | 2020-01-02 18:12:48 | 24 | 3 | 9 | [3 7 5 9] |
4 | 2020-01-03 19:45:32 | 11 | 2 | 9 | [9 2] |
5 | 2020-01-08 23:16:43 | 1 | 1 | 1 | [1] |
How can I get the original a column back? I don't want to join, and I don't want to use .over, because I need the group_by of the rolling later on and .over does not work with .rolling.
Edit: I am also not keen on using the following:
out = df.rolling(index_column="dt", period="2d").agg(
    [
        pl.sum("a").alias("sum_a"),
        pl.min("a").alias("min_a"),
        pl.max("a").alias("max_a"),
        pl.col("a").last().alias("a"),
    ]
)
Edit 2: why Expr.rolling() is not feasible, and why I need the group_by.
Given a more elaborate example:
dates = [
    "2020-01-01 13:45:48",
    "2020-01-01 16:42:13",
    "2020-01-01 16:45:09",
    "2020-01-02 18:12:48",
    "2020-01-03 19:45:32",
    "2020-01-08 23:16:43",
]
df_a = pl.DataFrame({"dt": dates, "a": [3, 7, 5, 9, 2, 1], "cat": ["one"] * 6}).with_columns(
    pl.col("dt").str.strptime(pl.Datetime).set_sorted()
)
df_b = pl.DataFrame({"dt": dates, "a": [3, 7, 5, 9, 2, 1], "cat": ["two"] * 6}).with_columns(
    pl.col("dt").str.strptime(pl.Datetime).set_sorted()
)
df = pl.concat([df_a, df_b])
 | dt | a | cat |
---|---|---|---|
0 | 2020-01-01 13:45:48 | 3 | one |
1 | 2020-01-01 16:42:13 | 7 | one |
2 | 2020-01-01 16:45:09 | 5 | one |
3 | 2020-01-02 18:12:48 | 9 | one |
4 | 2020-01-03 19:45:32 | 2 | one |
5 | 2020-01-08 23:16:43 | 1 | one |
6 | 2020-01-01 13:45:48 | 3 | two |
7 | 2020-01-01 16:42:13 | 7 | two |
8 | 2020-01-01 16:45:09 | 5 | two |
9 | 2020-01-02 18:12:48 | 9 | two |
10 | 2020-01-03 19:45:32 | 2 | two |
11 | 2020-01-08 23:16:43 | 1 | two |
and the code:
out = df.rolling(index_column="dt", period="2d", group_by="cat").agg(
    [
        pl.sum("a").alias("sum_a"),
        pl.min("a").alias("min_a"),
        pl.max("a").alias("max_a"),
        pl.col("a"),
    ]
)
 | cat | dt | sum_a | min_a | max_a | a |
---|---|---|---|---|---|---|
0 | one | 2020-01-01 13:45:48 | 3 | 3 | 3 | [3] |
1 | one | 2020-01-01 16:42:13 | 10 | 3 | 7 | [3 7] |
2 | one | 2020-01-01 16:45:09 | 15 | 3 | 7 | [3 7 5] |
3 | one | 2020-01-02 18:12:48 | 24 | 3 | 9 | [3 7 5 9] |
4 | one | 2020-01-03 19:45:32 | 11 | 2 | 9 | [9 2] |
5 | one | 2020-01-08 23:16:43 | 1 | 1 | 1 | [1] |
6 | two | 2020-01-01 13:45:48 | 3 | 3 | 3 | [3] |
7 | two | 2020-01-01 16:42:13 | 10 | 3 | 7 | [3 7] |
8 | two | 2020-01-01 16:45:09 | 15 | 3 | 7 | [3 7 5] |
9 | two | 2020-01-02 18:12:48 | 24 | 3 | 9 | [3 7 5 9] |
10 | two | 2020-01-03 19:45:32 | 11 | 2 | 9 | [9 2] |
11 | two | 2020-01-08 23:16:43 | 1 | 1 | 1 | [1] |
This does not work:
df.sort("dt").with_columns(sum=pl.sum("a").rolling(index_column="dt", period="2d").over(["cat"]))
Gives:
InvalidOperationError: rolling expression not allowed in aggregation
There are dedicated rolling_*_by expressions which can be used with .over():
df.with_columns(
    pl.col("a").rolling_sum_by("dt", "2d").over("cat").name.prefix("sum_"),
    pl.col("a").rolling_min_by("dt", "2d").over("cat").name.prefix("min_"),
    pl.col("a").rolling_max_by("dt", "2d").over("cat").name.prefix("max_"),
)
shape: (12, 6)
┌─────────────────────┬─────┬─────┬───────┬───────┬───────┐
│ dt ┆ a ┆ cat ┆ sum_a ┆ min_a ┆ max_a │
│ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │
│ datetime[μs] ┆ i64 ┆ str ┆ i64 ┆ i64 ┆ i64 │
╞═════════════════════╪═════╪═════╪═══════╪═══════╪═══════╡
│ 2020-01-01 13:45:48 ┆ 3 ┆ one ┆ 3 ┆ 3 ┆ 3 │
│ 2020-01-01 16:42:13 ┆ 7 ┆ one ┆ 10 ┆ 3 ┆ 7 │
│ 2020-01-01 16:45:09 ┆ 5 ┆ one ┆ 15 ┆ 3 ┆ 7 │
│ 2020-01-02 18:12:48 ┆ 9 ┆ one ┆ 24 ┆ 3 ┆ 9 │
│ 2020-01-03 19:45:32 ┆ 2 ┆ one ┆ 11 ┆ 2 ┆ 9 │
│ … ┆ … ┆ … ┆ … ┆ … ┆ … │
│ 2020-01-01 16:42:13 ┆ 7 ┆ two ┆ 10 ┆ 3 ┆ 7 │
│ 2020-01-01 16:45:09 ┆ 5 ┆ two ┆ 15 ┆ 3 ┆ 7 │
│ 2020-01-02 18:12:48 ┆ 9 ┆ two ┆ 24 ┆ 3 ┆ 9 │
│ 2020-01-03 19:45:32 ┆ 2 ┆ two ┆ 11 ┆ 2 ┆ 9 │
│ 2020-01-08 23:16:43 ┆ 1 ┆ two ┆ 1 ┆ 1 ┆ 1 │
└─────────────────────┴─────┴─────┴───────┴───────┴───────┘