How to get the aggregate of the records up to the current record using polars?


Given a dataset with records of an event, where the event can happen multiple times for the same ID, I want to find the aggregate of the previous records for that ID. Let's say I have the following table:

id datetime value
123 20230101T00:00 2
123 20230101T01:00 5
123 20230101T03:00 7
123 20230101T04:00 1
456 20230201T04:00 1
456 20230201T07:00 1
456 20230205T04:00 1

I want to create a new column "agg" that sums the earlier values of "value" for the same id (0 when there are none), giving the following table:

id datetime value agg
123 20230101T00:00 2 0
123 20230101T01:00 5 2
123 20230101T03:00 7 7
123 20230101T04:00 1 14
456 20230201T04:00 1 0
456 20230201T07:00 1 1
456 20230205T04:00 1 2

The Polars documentation describes window functions, but it is not clear how to aggregate only the rows that precede the current record. I know it is possible to do this in PySpark with:

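from pyspark.sql import Window
from pyspark.sql import functions as f

# All rows of the same id that come before the current row, ordered by datetime
window = Window.partitionBy('id').orderBy('datetime').rowsBetween(Window.unboundedPreceding, -1)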

(
    df_pyspark
    .withColumn('agg', f.sum('value').over(window).cast('int'))
    .fillna(0, subset=['agg'])
)

EDIT: Apparently there is no trivial way to express this window in polars; it is being discussed in open issue 8976 in the polars repository.


EDIT: Fixed the expected table after applying the window function.


Solution

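  • Your description sounds like .shift().cum_sum().over(): shift "value" down by one row within each id, then take the running sum. This assumes the rows are already ordered by datetime within each id. The first row of each id comes out as null rather than 0, because nothing precedes it.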

    import polars as pl
    
    df = pl.from_repr("""
    ┌─────┬─────────────────────┬───────┐
    │ id  ┆ datetime            ┆ value │
    │ --- ┆ ---                 ┆ ---   │
    │ i64 ┆ datetime[ns]        ┆ i64   │
    ╞═════╪═════════════════════╪═══════╡
    │ 123 ┆ 2023-01-01 00:00:00 ┆ 2     │
    │ 123 ┆ 2023-01-01 01:00:00 ┆ 5     │
    │ 123 ┆ 2023-01-01 03:00:00 ┆ 7     │
    │ 123 ┆ 2023-01-01 04:00:00 ┆ 1     │
    │ 456 ┆ 2023-02-01 04:00:00 ┆ 1     │
    │ 456 ┆ 2023-02-01 07:00:00 ┆ 1     │
    │ 456 ┆ 2023-02-05 04:00:00 ┆ 1     │
    └─────┴─────────────────────┴───────┘
    """)
    
    df.with_columns(agg = pl.col('value').shift().cum_sum().over('id'))
    
    shape: (7, 4)
    ┌─────┬─────────────────────┬───────┬──────┐
    │ id  ┆ datetime            ┆ value ┆ agg  │
    │ --- ┆ ---                 ┆ ---   ┆ ---  │
    │ i64 ┆ datetime[ns]        ┆ i64   ┆ i64  │
    ╞═════╪═════════════════════╪═══════╪══════╡
    │ 123 ┆ 2023-01-01 00:00:00 ┆ 2     ┆ null │
    │ 123 ┆ 2023-01-01 01:00:00 ┆ 5     ┆ 2    │
    │ 123 ┆ 2023-01-01 03:00:00 ┆ 7     ┆ 7    │
    │ 123 ┆ 2023-01-01 04:00:00 ┆ 1     ┆ 14   │
    │ 456 ┆ 2023-02-01 04:00:00 ┆ 1     ┆ null │
    │ 456 ┆ 2023-02-01 07:00:00 ┆ 1     ┆ 1    │
    │ 456 ┆ 2023-02-05 04:00:00 ┆ 1     ┆ 2    │
    └─────┴─────────────────────┴───────┴──────┘