
Apply multiple window sizes to rolling aggregation functions in polars dataframe


In a number of aggregation functions, such as rolling_mean, rolling_max and rolling_min, the window_size argument is expected to be an int.

I am wondering how to compute the results efficiently when I have a list of window sizes.

Consider the following dataframe:

import polars as pl

pl.Config(tbl_rows=-1)

df = pl.DataFrame(
    {
        "symbol": ["A", "A", "A", "A", "A", "B", "B", "B", "B"],
        "price": [100, 110, 105, 103, 107, 200, 190, 180, 185],
    }
)

shape: (9, 2)
┌────────┬───────┐
│ symbol ┆ price │
│ ---    ┆ ---   │
│ str    ┆ i64   │
╞════════╪═══════╡
│ A      ┆ 100   │
│ A      ┆ 110   │
│ A      ┆ 105   │
│ A      ┆ 103   │
│ A      ┆ 107   │
│ B      ┆ 200   │
│ B      ┆ 190   │
│ B      ┆ 180   │
│ B      ┆ 185   │
└────────┴───────┘

Let's say I have a list with n elements, such as periods = [2, 3]. I am looking for a way to compute the rolling means for all periods, grouped by symbol, in parallel. Speed and memory efficiency are essential.

The result should be a tidy/long dataframe like this:

shape: (18, 4)
┌────────┬───────┬─────────────┬──────────────┐
│ symbol ┆ price ┆ mean_period ┆ rolling_mean │
│ ---    ┆ ---   ┆ ---         ┆ ---          │
│ str    ┆ i64   ┆ u8          ┆ f64          │
╞════════╪═══════╪═════════════╪══════════════╡
│ A      ┆ 100   ┆ 2           ┆ null         │
│ A      ┆ 110   ┆ 2           ┆ 105.0        │
│ A      ┆ 105   ┆ 2           ┆ 107.5        │
│ A      ┆ 103   ┆ 2           ┆ 104.0        │
│ A      ┆ 107   ┆ 2           ┆ 105.0        │
│ B      ┆ 200   ┆ 2           ┆ null         │
│ B      ┆ 190   ┆ 2           ┆ 195.0        │
│ B      ┆ 180   ┆ 2           ┆ 185.0        │
│ B      ┆ 185   ┆ 2           ┆ 182.5        │
│ A      ┆ 100   ┆ 3           ┆ null         │
│ A      ┆ 110   ┆ 3           ┆ null         │
│ A      ┆ 105   ┆ 3           ┆ 105.0        │
│ A      ┆ 103   ┆ 3           ┆ 106.0        │
│ A      ┆ 107   ┆ 3           ┆ 105.0        │
│ B      ┆ 200   ┆ 3           ┆ null         │
│ B      ┆ 190   ┆ 3           ┆ null         │
│ B      ┆ 180   ┆ 3           ┆ 190.0        │
│ B      ┆ 185   ┆ 3           ┆ 185.0        │
└────────┴───────┴─────────────┴──────────────┘

Solution

  • You can use a generator expression to build one DataFrame per value in the periods list and then concat() them into a single long DataFrame:

    periods = [2, 3]
    
    pl.concat(
        df.with_columns(
            mean_period = pl.lit(p),
            rolling_mean = pl.col.price.rolling_mean(p).over("symbol")
        )
        for p in periods
    )
    
    ┌────────┬───────┬─────────────┬──────────────┐
    │ symbol ┆ price ┆ mean_period ┆ rolling_mean │
    │ ---    ┆ ---   ┆ ---         ┆ ---          │
    │ str    ┆ i64   ┆ i32         ┆ f64          │
    ╞════════╪═══════╪═════════════╪══════════════╡
    │ A      ┆ 100   ┆ 2           ┆ null         │
    │ A      ┆ 110   ┆ 2           ┆ 105.0        │
    │ A      ┆ 105   ┆ 2           ┆ 107.5        │
    │ A      ┆ 103   ┆ 2           ┆ 104.0        │
    │ A      ┆ 107   ┆ 2           ┆ 105.0        │
    │ B      ┆ 200   ┆ 2           ┆ null         │
    │ B      ┆ 190   ┆ 2           ┆ 195.0        │
    │ B      ┆ 180   ┆ 2           ┆ 185.0        │
    │ B      ┆ 185   ┆ 2           ┆ 182.5        │
    │ A      ┆ 100   ┆ 3           ┆ null         │
    │ A      ┆ 110   ┆ 3           ┆ null         │
    │ A      ┆ 105   ┆ 3           ┆ 105.0        │
    │ A      ┆ 103   ┆ 3           ┆ 106.0        │
    │ A      ┆ 107   ┆ 3           ┆ 105.0        │
    │ B      ┆ 200   ┆ 3           ┆ null         │
    │ B      ┆ 190   ┆ 3           ┆ null         │
    │ B      ┆ 180   ┆ 3           ┆ 190.0        │
    │ B      ┆ 185   ┆ 3           ┆ 185.0        │
    └────────┴───────┴─────────────┴──────────────┘