Use the rolling function of polars to get a list of all values in the rolling windows


I'd like to use the rolling function to get a list of all values in the rolling windows.

I tried it with the following code snippet:

import polars as pl

df = pl.DataFrame(
    {
        "A": [1, 2, 9, 2, 13],
    }
)

df.select(
    pl.col("A").rolling_map(lambda s: s, 3),
)

This outputs:

┌──────┐
│ A    │
│ ---  │
│ i64  │
╞══════╡
│ null │
│ null │
│ 1    │
│ 2    │
│ 9    │
└──────┘

but what I need is:

shape: (5, 1)
┌─────────────────┐
│ A               │
│ ---             │
│ list[i64]       │
╞═════════════════╡
│ [null, null, 1] │
│ [null, 1, 2]    │
│ [1, 2, 9]       │
│ [2, 9, 2]       │
│ [9, 2, 13]      │
└─────────────────┘

Does anyone have an idea how to do this in Polars in an easy way?


Solution

  • You could create lagged columns and collect them into a list.

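    (
        df
        # One lagged copy of "A" per window position; shift(0) is "A" itself.
        .with_columns(pl.col("A").shift(i).alias(f"A_lag_{i}") for i in range(3))
        .select(
            # Largest lag first, so each list runs from oldest to newest value.
            pl.concat_list(reversed([f"A_lag_{i}" for i in range(3)])).alias("A_rolling")
        )
    )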
    

    Outputs:

    shape: (5, 1)
    ┌─────────────────┐
    │ A_rolling       │
    │ ---             │
    │ list[i64]       │
    ╞═════════════════╡
    │ [null, null, 1] │
    │ [null, 1, 2]    │
    │ [1, 2, 9]       │
    │ [2, 9, 2]       │
    │ [9, 2, 13]      │
    └─────────────────┘
    
    

    Let's break it down:

    pl.col("A").shift(i) for i in range(3) creates the new lagged columns: shift(0) is a copy of "A", and each larger i moves the values down by i rows, filling the top with null.

    The lagged columns stay plain i64 columns. No reshape into lists is needed, because pl.concat_list accepts non-list columns and treats each value as a single list element.

    This results in this intermediate DataFrame:

    shape: (5, 4)
    ┌─────┬─────────┬─────────┬─────────┐
    │ A   ┆ A_lag_0 ┆ A_lag_1 ┆ A_lag_2 │
    │ --- ┆ ---     ┆ ---     ┆ ---     │
    │ i64 ┆ i64     ┆ i64     ┆ i64     │
    ╞═════╪═════════╪═════════╪═════════╡
    │ 1   ┆ 1       ┆ null    ┆ null    │
    │ 2   ┆ 2       ┆ 1       ┆ null    │
    │ 9   ┆ 9       ┆ 2       ┆ 1       │
    │ 2   ┆ 2       ┆ 9       ┆ 2       │
    │ 13  ┆ 13      ┆ 2       ┆ 9       │
    └─────┴─────────┴─────────┴─────────┘
    
    

    Finally, we concatenate them in reversed order (largest lag first, so each list runs from oldest to newest value) and name the output "A_rolling":

    pl.concat_list(reversed([f"A_lag_{i}" for i in range(3)])).alias("A_rolling")