python-polars

How to do a rolling mean average grouped by other columns and specify a minimum period in Polars?


Let's say I have a dataset similar to this one except we have a Price for each hour of the day each day:

df = pl.from_repr("""
┌─────────┬─────────────────────┬─────────┬──────┐
│ Group   ┆ Time                ┆ Price   ┆ Hour │
│ ---     ┆ ---                 ┆ ---     ┆ ---  │
│ str     ┆ datetime[ns]        ┆ f64     ┆ i8   │
╞═════════╪═════════════════════╪═════════╪══════╡
│ Group1  ┆ 2021-01-03 00:00:00 ┆ 15.6328 ┆ 0    │
│ Group1  ┆ 2021-01-03 05:00:00 ┆ 17.1562 ┆ 5    │
│ Group1  ┆ 2021-01-03 08:00:00 ┆ 13.9062 ┆ 8    │
│ Group2  ┆ 2021-01-03 10:00:00 ┆ 18.5625 ┆ 10   │
│ Group2  ┆ 2021-01-03 00:00:00 ┆ 28.375  ┆ 0    │
│ Group2  ┆ 2021-01-03 13:00:00 ┆ 15.4219 ┆ 13   │
└─────────┴─────────────────────┴─────────┴──────┘
""")

In Polars in Python, how could I compute the rolling mean of Price per Group per Hour over the last 30 days? For example, for Group1 at Hour 0, I would take the rows from the past 30 days where Group=Group1 and Hour=0 and take their mean.

I implemented the following with rolling, but I can't specify a min_period in rolling:

df.rolling('Time', group_by=['Group', 'Hour'], period='30d').agg(
    pl.col('Price').mean().alias('PriceAvg')
)

Then there's the rolling_mean method, which does let you specify a min_periods, but using it inside rolling is much slower.

What is the best way of doing this?


Solution

  • You can do it with a window function (over) combined with rolling_mean_by.

    Below is an example. I slightly changed your data to ensure I could test the formula.

    df = pl.from_repr("""
    ┌────────┬─────────────────────┬───────┐
    │ Group  ┆ Time                ┆ Price │
    │ ---    ┆ ---                 ┆ ---   │
    │ str    ┆ datetime[μs]        ┆ i64   │
    ╞════════╪═════════════════════╪═══════╡
    │ group1 ┆ 2021-01-03 00:00:00 ┆ 1     │
    │ group1 ┆ 2021-01-02 00:00:00 ┆ 2     │
    │ group1 ┆ 2020-12-20 00:00:00 ┆ 3     │
    │ group1 ┆ 2020-06-01 00:00:00 ┆ 4     │
    │ group1 ┆ 2021-01-03 05:00:00 ┆ 5     │
    │ group2 ┆ 2021-01-01 00:00:00 ┆ 6     │
    └────────┴─────────────────────┴───────┘
    """).with_row_index()
    
    (
        df.sort('Group', 'Time')  # rolling_mean_by requires Time to be sorted within each window group
        .with_columns(
            avg_price_same_hour_30_days=
                pl.col('Price')
                .rolling_mean_by(window_size='30d', by='Time', closed='both')
                # partition by Group and hour-of-day, so each window only sees
                # rows from the same group at the same hour
                .over('Group', pl.col('Time').dt.hour())
        )
        .sort('index')  # restore the original row order
    )
    

    You can skip the last sort; I included it to make the results easier to read:

    shape: (6, 5)
    ┌───────┬────────┬─────────────────────┬───────┬─────────────────────────────┐
    │ index ┆ Group  ┆ Time                ┆ Price ┆ avg_price_same_hour_30_days │
    │ ---   ┆ ---    ┆ ---                 ┆ ---   ┆ ---                         │
    │ u32   ┆ str    ┆ datetime[μs]        ┆ i64   ┆ f64                         │
    ╞═══════╪════════╪═════════════════════╪═══════╪═════════════════════════════╡
    │ 0     ┆ group1 ┆ 2021-01-03 00:00:00 ┆ 1     ┆ 2.0                         │
    │ 1     ┆ group1 ┆ 2021-01-02 00:00:00 ┆ 2     ┆ 2.5                         │
    │ 2     ┆ group1 ┆ 2020-12-20 00:00:00 ┆ 3     ┆ 3.0                         │
    │ 3     ┆ group1 ┆ 2020-06-01 00:00:00 ┆ 4     ┆ 4.0                         │
    │ 4     ┆ group1 ┆ 2021-01-03 05:00:00 ┆ 5     ┆ 5.0                         │
    │ 5     ┆ group2 ┆ 2021-01-01 00:00:00 ┆ 6     ┆ 6.0                         │
    └───────┴────────┴─────────────────────┴───────┴─────────────────────────────┘
    

    A review of a couple of results:

    • For row 0, the 30-day window runs from 2020-12-04 to 2021-01-03, restricted to the same hour and the same group. That means rows 0, 1, and 2 are included, and the result is (1 + 2 + 3) / 3 = 2.

    • For row 1, the 30-day window runs from 2020-12-03 to 2021-01-02, restricted to the same hour and the same group. That means rows 1 and 2 are included, and the result is (2 + 3) / 2 = 2.5.

    For each of the remaining rows, only the current row falls inside its window.