Tags: python, python-polars

Calculate group mean for an int_range column in Polars dataframe


Update: This has been resolved in Polars. The code now runs without error.


I have the following code in Polars:

import datetime
import polars as pl

df = pl.DataFrame(
    {
        "id": [1, 2, 1, 2, 1, 2, 3],
        "date": [
            datetime.date(2022, 1, 1),
            datetime.date(2022, 1, 1),
            datetime.date(2022, 1, 11),
            datetime.date(2022, 1, 11),
            datetime.date(2022, 2, 1),
            datetime.date(2022, 2, 1),
            datetime.date(2022, 2, 1),
        ],
        "value": [1, 2, 3, None, 5, 6, None],
    }
)

(df.group_by_dynamic("date", group_by="id", every="1mo", period="1mo", closed="both")
   .agg(
       pl.int_range(1, pl.len() + 1) 
       - pl.int_range(1, pl.len() + 1).filter(pl.col("value").is_not_null()).mean(),
   )
)

But when I run it, I get the following error, which I don't quite understand:

pyo3_runtime.PanicException: index out of bounds: the len is 1 but the index is 1

The behavior I want is: for each group, create a sequence from 1 to the number of rows in that group, and subtract from it the mean of that sequence taken over the rows where "value" is not null (returning null if all "value" entries in that group are null). For id 2 in January, for example, the sequence is [1, 2, 3] and "value" is non-null only at positions 1 and 3, so the mean is 2 and the result is [-1.0, 0.0, 1.0].

To be more specific, the result I want is:

shape: (5, 3)
┌─────┬────────────┬──────────────────┐
│ id  ┆ date       ┆ arange           │
│ --- ┆ ---        ┆ ---              │
│ i64 ┆ date       ┆ list[f64]        │
╞═════╪════════════╪══════════════════╡
│ 1   ┆ 2022-01-01 ┆ [-1.0, 0.0, 1.0] │
│ 1   ┆ 2022-02-01 ┆ [0.0]            │
│ 2   ┆ 2022-01-01 ┆ [-1.0, 0.0, 1.0] │
│ 2   ┆ 2022-02-01 ┆ [0.0]            │
│ 3   ┆ 2022-02-01 ┆ [null]           │
└─────┴────────────┴──────────────────┘

How can I achieve this?


Solution

  • As a workaround, perhaps you could explode first, then implement the logic.

    (
       df.group_by_dynamic(index_column="date", group_by="id", every="1mo", period="1mo", closed="both")
         .agg(pl.exclude("date"))      # collect "value" into a list per (id, month) window
         .with_row_index("group")      # tag each window with a unique group index
         .explode("value")             # back to one row per original value
         .with_columns(
             (pl.int_range(1, pl.len() + 1)
             - pl.int_range(1, pl.len() + 1).filter(pl.col("value").is_not_null()).mean())
             .over("group")            # sequence minus mean of non-null positions, per group
         )
    )
    
    shape: (9, 5)
    ┌───────┬─────┬────────────┬───────┬─────────┐
    │ group ┆ id  ┆ date       ┆ value ┆ literal │
    │ ---   ┆ --- ┆ ---        ┆ ---   ┆ ---     │
    │ u32   ┆ i64 ┆ date       ┆ i64   ┆ f64     │
    ╞═══════╪═════╪════════════╪═══════╪═════════╡
    │ 0     ┆ 1   ┆ 2022-01-01 ┆ 1     ┆ -1.0    │
    │ 0     ┆ 1   ┆ 2022-01-01 ┆ 3     ┆ 0.0     │
    │ 0     ┆ 1   ┆ 2022-01-01 ┆ 5     ┆ 1.0     │
    │ 1     ┆ 1   ┆ 2022-02-01 ┆ 5     ┆ 0.0     │
    │ 2     ┆ 2   ┆ 2022-01-01 ┆ 2     ┆ -1.0    │
    │ 2     ┆ 2   ┆ 2022-01-01 ┆ null  ┆ 0.0     │
    │ 2     ┆ 2   ┆ 2022-01-01 ┆ 6     ┆ 1.0     │
    │ 3     ┆ 2   ┆ 2022-02-01 ┆ 6     ┆ 0.0     │
    │ 4     ┆ 3   ┆ 2022-02-01 ┆ null  ┆ null    │
    └───────┴─────┴────────────┴───────┴─────────┘
    

    You could use the group column to piece them back together afterwards.
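    For example, a minimal sketch (assuming the exploded frame above has been
    assigned to a variable named out, and that the difference column comes out
    named "literal" as in the printed output):

    (
       out.group_by("group", "id", "date", maintain_order=True)
          .agg(pl.col("literal").alias("arange"))   # collect each group's differences back into a list
          .drop("group")
    )

    Here maintain_order=True keeps the windows in their original order, so the
    result matches the shape of the desired output.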