What's the right way to perform a group_by + rolling aggregate operation in polars? For some reason, calling ewm_mean
over a rolling group_by returns, for each row, the list of all exponentially weighted means in that row's time window. For example, take the dataframe below:
import polars as pl

portfolios = pl.from_repr("""
┌─────────────────────┬────────┬───────────┐
│ ts ┆ symbol ┆ signal_0 │
│ --- ┆ --- ┆ --- │
│ datetime[μs] ┆ str ┆ f64 │
╞═════════════════════╪════════╪═══════════╡
│ 2022-02-14 09:20:00 ┆ A ┆ -1.704301 │
│ 2022-02-14 09:20:00 ┆ AA ┆ -1.181743 │
│ 2022-02-14 09:50:00 ┆ A ┆ 1.040125 │
│ 2022-02-14 09:50:00 ┆ AA ┆ 0.776798 │
│ 2022-02-14 10:20:00 ┆ A ┆ 1.934686 │
│ 2022-02-14 10:20:00 ┆ AA ┆ 1.480892 │
│ 2022-02-14 10:50:00 ┆ A ┆ 2.073418 │
│ 2022-02-14 10:50:00 ┆ AA ┆ 1.623698 │
│ 2022-02-14 11:20:00 ┆ A ┆ 2.088835 │
│ 2022-02-14 11:20:00 ┆ AA ┆ 1.741544 │
└─────────────────────┴────────┴───────────┘
""")
Here, I want to group by symbol and get the exponentially weighted mean of signal_0 at every timestamp. Unfortunately, this returns a list per row instead of a single value:
portfolios.rolling("ts", group_by="symbol", period="1d").agg(
    pl.col("signal_0").ewm_mean(half_life=0.1).alias("signal_0_mean")
)
shape: (10, 3)
┌────────┬─────────────────────┬─────────────────────────────────┐
│ symbol ┆ ts ┆ signal_0_mean │
│ --- ┆ --- ┆ --- │
│ str ┆ datetime[μs] ┆ list[f64] │
╞════════╪═════════════════════╪═════════════════════════════════╡
│ A ┆ 2022-02-14 09:20:00 ┆ [-1.704301] │
│ A ┆ 2022-02-14 09:50:00 ┆ [-1.704301, 1.037448] │
│ A ┆ 2022-02-14 10:20:00 ┆ [-1.704301, 1.037448, 1.93381] │
│ A ┆ 2022-02-14 10:50:00 ┆ [-1.704301, 1.037448, … 2.0732… │
│ A ┆ 2022-02-14 11:20:00 ┆ [-1.704301, 1.037448, … 2.0888… │
│ AA ┆ 2022-02-14 09:20:00 ┆ [-1.181743] │
│ AA ┆ 2022-02-14 09:50:00 ┆ [-1.181743, 0.774887] │
│ AA ┆ 2022-02-14 10:20:00 ┆ [-1.181743, 0.774887, 1.480203… │
│ AA ┆ 2022-02-14 10:50:00 ┆ [-1.181743, 0.774887, … 1.6235… │
│ AA ┆ 2022-02-14 11:20:00 ┆ [-1.181743, 0.774887, … 1.7414… │
└────────┴─────────────────────┴─────────────────────────────────┘
If I wanted to do this in pandas, I would write:
portfolios.to_pandas().set_index(["ts", "symbol"]).groupby(level=1)["signal_0"].transform(
    lambda x: x.ewm(halflife=10).mean()
)
Which would yield:
ts symbol
2022-02-14 09:20:00 A -1.704301
AA -1.181743
2022-02-14 09:50:00 A -0.284550
AA -0.168547
2022-02-14 10:20:00 A 0.507021
AA 0.419785
2022-02-14 10:50:00 A 0.940226
AA 0.752741
2022-02-14 11:20:00 A 1.202843
AA 0.978820
Name: signal_0, dtype: float64
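You were close. Since ewm_mean produces an estimate for each observation in each window, you simply need to specify
that you want the last calculated value in each rolling window. Note that the snippet below uses half_life=10, matching
the halflife=10 in your pandas code, rather than the 0.1 in your polars attempt.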
(
    portfolios
    .rolling("ts", group_by="symbol", period="1d")
    .agg(
        pl.col("signal_0").ewm_mean(half_life=10).last().alias("signal_0_mean")
    )
    .sort("ts", "symbol")
)
shape: (10, 3)
┌────────┬─────────────────────┬───────────────┐
│ symbol ┆ ts ┆ signal_0_mean │
│ --- ┆ --- ┆ --- │
│ str ┆ datetime[μs] ┆ f64 │
╞════════╪═════════════════════╪═══════════════╡
│ A ┆ 2022-02-14 09:20:00 ┆ -1.704301 │
│ AA ┆ 2022-02-14 09:20:00 ┆ -1.181743 │
│ A ┆ 2022-02-14 09:50:00 ┆ -0.28455 │
│ AA ┆ 2022-02-14 09:50:00 ┆ -0.168547 │
│ A ┆ 2022-02-14 10:20:00 ┆ 0.507021 │
│ AA ┆ 2022-02-14 10:20:00 ┆ 0.419785 │
│ A ┆ 2022-02-14 10:50:00 ┆ 0.940226 │
│ AA ┆ 2022-02-14 10:50:00 ┆ 0.752741 │
│ A ┆ 2022-02-14 11:20:00 ┆ 1.202844 │
│ AA ┆ 2022-02-14 11:20:00 ┆ 0.97882 │
└────────┴─────────────────────┴───────────────┘
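To see why taking the last value works, here is a minimal plain-Python sketch of the calculation for symbol A. It assumes the default adjusted weighting (adjust=True in both polars and pandas) and the usual half-life definition, alpha = 1 - exp(-ln(2) / half_life). The helper ewm_mean and the variable names are illustrative only, not part of either library.

```python
import math


def ewm_mean(values, half_life):
    """Adjusted exponentially weighted mean: one estimate per observation."""
    # Weight retained per step, i.e. 1 - alpha.
    decay = math.exp(-math.log(2) / half_life)
    out, num, den = [], 0.0, 0.0
    for x in values:
        num = x + decay * num    # weighted sum of observations so far
        den = 1.0 + decay * den  # sum of the weights so far
        out.append(num / den)
    return out


# signal_0 for symbol "A", in timestamp order (taken from the question).
signal_a = [-1.704301, 1.040125, 1.934686, 2.073418, 2.088835]

# All rows fall within one day, so each row's "1d" window
# contains every row up to and including that row.
windows = [signal_a[: i + 1] for i in range(len(signal_a))]

# Without .last(): one list of estimates per window.
per_window = [ewm_mean(w, half_life=10) for w in windows]

# With .last(): a single estimate per window.
last_only = [estimates[-1] for estimates in per_window]

print(last_only)
```

Each entry of per_window has the same shape as a row of the list[f64] column in the question, and last_only should agree, up to rounding, with the values for symbol A in both the pandas output and the table above.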