
Polars - How to compute rolling ewm grouped by column?


What's the right way to perform a group_by + rolling aggregation in Polars? When I compute ewm_mean inside a rolling group_by, I get back a list of all the EWM values in each time window instead of a single value per timestamp. For example, take the DataFrame below:

import polars as pl

portfolios = pl.from_repr("""
┌─────────────────────┬────────┬───────────┐
│ ts                  ┆ symbol ┆ signal_0  │
│ ---                 ┆ ---    ┆ ---       │
│ datetime[μs]        ┆ str    ┆ f64       │
╞═════════════════════╪════════╪═══════════╡
│ 2022-02-14 09:20:00 ┆ A      ┆ -1.704301 │
│ 2022-02-14 09:20:00 ┆ AA     ┆ -1.181743 │
│ 2022-02-14 09:50:00 ┆ A      ┆ 1.040125  │
│ 2022-02-14 09:50:00 ┆ AA     ┆ 0.776798  │
│ 2022-02-14 10:20:00 ┆ A      ┆ 1.934686  │
│ 2022-02-14 10:20:00 ┆ AA     ┆ 1.480892  │
│ 2022-02-14 10:50:00 ┆ A      ┆ 2.073418  │
│ 2022-02-14 10:50:00 ┆ AA     ┆ 1.623698  │
│ 2022-02-14 11:20:00 ┆ A      ┆ 2.088835  │
│ 2022-02-14 11:20:00 ┆ AA     ┆ 1.741544  │
└─────────────────────┴────────┴───────────┘
""")

Here, I want to group by symbol and get the exponentially weighted moving average of signal_0 at every timestamp. Unfortunately, this doesn't work:

portfolios.rolling("ts", group_by="symbol", period="1d").agg(
    pl.col("signal_0").ewm_mean(half_life=0.1).alias("signal_0_mean")
)
shape: (10, 3)
┌────────┬─────────────────────┬─────────────────────────────────┐
│ symbol ┆ ts                  ┆ signal_0_mean                   │
│ ---    ┆ ---                 ┆ ---                             │
│ str    ┆ datetime[μs]        ┆ list[f64]                       │
╞════════╪═════════════════════╪═════════════════════════════════╡
│ A      ┆ 2022-02-14 09:20:00 ┆ [-1.704301]                     │
│ A      ┆ 2022-02-14 09:50:00 ┆ [-1.704301, 1.037448]           │
│ A      ┆ 2022-02-14 10:20:00 ┆ [-1.704301, 1.037448, 1.93381]  │
│ A      ┆ 2022-02-14 10:50:00 ┆ [-1.704301, 1.037448, … 2.0732… │
│ A      ┆ 2022-02-14 11:20:00 ┆ [-1.704301, 1.037448, … 2.0888… │
│ AA     ┆ 2022-02-14 09:20:00 ┆ [-1.181743]                     │
│ AA     ┆ 2022-02-14 09:50:00 ┆ [-1.181743, 0.774887]           │
│ AA     ┆ 2022-02-14 10:20:00 ┆ [-1.181743, 0.774887, 1.480203… │
│ AA     ┆ 2022-02-14 10:50:00 ┆ [-1.181743, 0.774887, … 1.6235… │
│ AA     ┆ 2022-02-14 11:20:00 ┆ [-1.181743, 0.774887, … 1.7414… │
└────────┴─────────────────────┴─────────────────────────────────┘

If I wanted to do this in pandas, I would write:

portfolios.to_pandas().set_index(["ts", "symbol"]).groupby(level=1)["signal_0"].transform(
    lambda x: x.ewm(halflife=10).mean()
)

Which would yield:

ts                   symbol
2022-02-14 09:20:00  A        -1.704301
                     AA       -1.181743
2022-02-14 09:50:00  A        -0.284550
                     AA       -0.168547
2022-02-14 10:20:00  A         0.507021
                     AA        0.419785
2022-02-14 10:50:00  A         0.940226
                     AA        0.752741
2022-02-14 11:20:00  A         1.202843
                     AA        0.978820
Name: signal_0, dtype: float64
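For reference, pandas' ewm(halflife=h).mean() with its default adjust=True is a weighted average in which each older observation's weight shrinks by a constant factor. A worked check of the second value for symbol A with h = 10 (Polars' ewm_mean also defaults to adjust=True, which is why the two libraries can agree):

```latex
\alpha = 1 - 2^{-1/h}, \qquad
y_t = \frac{\sum_{i=0}^{t} (1-\alpha)^{i}\, x_{t-i}}{\sum_{i=0}^{t} (1-\alpha)^{i}}

h = 10 \;\Rightarrow\; 1 - \alpha = 2^{-1/10} \approx 0.933033

y_1 = \frac{1.040125 + 0.933033 \cdot (-1.704301)}{1 + 0.933033}
    \approx \frac{-0.550044}{1.933033} \approx -0.284550
```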

Solution

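  • You were close. Inside .agg, ewm_mean returns one value for every row in the window, which is why each row of your result holds a list. The current row is always the last row of its window, so append .last() to keep just its EWM value (half_life=10 below matches the pandas call). Keep in mind that the EWM restarts at the first row of each window, so this matches pandas only while the window covers the symbol's entire history; with period="1d" that holds here because all rows fall on the same day.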

    (
        portfolios
        .rolling("ts", group_by="symbol", period="1d")
        .agg(
            pl.col("signal_0").ewm_mean(half_life=10).last().alias("signal_0_mean")
        )
        .sort('ts', 'symbol')
    )
    
    shape: (10, 3)
    ┌────────┬─────────────────────┬───────────────┐
    │ symbol ┆ ts                  ┆ signal_0_mean │
    │ ---    ┆ ---                 ┆ ---           │
    │ str    ┆ datetime[μs]        ┆ f64           │
    ╞════════╪═════════════════════╪═══════════════╡
    │ A      ┆ 2022-02-14 09:20:00 ┆ -1.704301     │
    │ AA     ┆ 2022-02-14 09:20:00 ┆ -1.181743     │
    │ A      ┆ 2022-02-14 09:50:00 ┆ -0.28455      │
    │ AA     ┆ 2022-02-14 09:50:00 ┆ -0.168547     │
    │ A      ┆ 2022-02-14 10:20:00 ┆ 0.507021      │
    │ AA     ┆ 2022-02-14 10:20:00 ┆ 0.419785      │
    │ A      ┆ 2022-02-14 10:50:00 ┆ 0.940226      │
    │ AA     ┆ 2022-02-14 10:50:00 ┆ 0.752741      │
    │ A      ┆ 2022-02-14 11:20:00 ┆ 1.202844      │
    │ AA     ┆ 2022-02-14 11:20:00 ┆ 0.97882       │
    └────────┴─────────────────────┴───────────────┘
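To see what the rolling aggregation does row by row, here is a plain-Python sketch (no Polars). It assumes the adjusted EWM formula and the default right-closed window (t - period, t]; the helpers ewm_mean_adjusted and rolling_ewm_last are hypothetical names for illustration only. It shows that taking the last value of each window's EWM reproduces the expanding EWM when the window spans the whole history, and diverges once older rows fall out of the window:

```python
from datetime import datetime, timedelta


def ewm_mean_adjusted(values, half_life):
    """Adjusted EWM: weight (1 - alpha)**i on the i-th most recent value.

    Returns one smoothed value per input row.
    """
    decay = 0.5 ** (1.0 / half_life)  # this is 1 - alpha
    num = den = 0.0
    out = []
    for x in values:
        num = x + decay * num    # running weighted sum of observations
        den = 1.0 + decay * den  # running sum of weights
        out.append(num / den)
    return out


def rolling_ewm_last(timestamps, values, period, half_life):
    """Hypothetical helper: for each row, recompute the EWM from scratch over
    the rows in (t - period, t] and keep only the last (current) value.
    Assumes timestamps are sorted ascending."""
    result = []
    for t in timestamps:
        window = [v for ts, v in zip(timestamps, values) if t - period < ts <= t]
        result.append(ewm_mean_adjusted(window, half_life)[-1])
    return result


# Symbol A from the question's table.
ts = [datetime(2022, 2, 14, h, m) for h, m in [(9, 20), (9, 50), (10, 20), (10, 50), (11, 20)]]
signal_a = [-1.704301, 1.040125, 1.934686, 2.073418, 2.088835]

expanding = ewm_mean_adjusted(signal_a, half_life=10)
one_day = rolling_ewm_last(ts, signal_a, timedelta(days=1), half_life=10)
one_hour = rolling_ewm_last(ts, signal_a, timedelta(hours=1), half_life=10)

print([round(v, 6) for v in one_day])   # identical to expanding: every row is within 1 day
print([round(v, 6) for v in one_hour])  # differs from the third row on: older rows drop out
```

If what you actually want is the pandas groupby(...).transform(lambda x: x.ewm(...).mean()) behaviour with no time window at all, a Polars window expression should give it directly, e.g. pl.col("signal_0").ewm_mean(half_life=10).over("symbol") inside with_columns, assuming the frame is sorted by ts within each symbol.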