
Create a conditional cumulative sum in Polars


Example dataframe:

import polars as pl

testDf = pl.DataFrame({
    "Date1": ["2024-04-01", "2024-04-06", "2024-04-07", "2024-04-10", "2024-04-11"],
    "Date2": ["2024-04-04", "2024-04-07", "2024-04-09", "2024-04-10", "2024-04-15"],
    "Date3": ["2024-04-07", "2024-04-08", "2024-04-10", "2024-05-15", "2024-04-21"],
    "Value": [10, 15, -20, 5, 30],
}).with_columns(
    # Parse the ISO date strings into pl.Date columns
    pl.col("Date1", "Date2", "Date3").str.to_date("%Y-%m-%d")
)
shape: (5, 4)
┌────────────┬────────────┬────────────┬───────┐
│ Date1      ┆ Date2      ┆ Date3      ┆ Value │
│ ---        ┆ ---        ┆ ---        ┆ ---   │
│ date       ┆ date       ┆ date       ┆ i64   │
╞════════════╪════════════╪════════════╪═══════╡
│ 2024-04-01 ┆ 2024-04-04 ┆ 2024-04-07 ┆ 10    │
│ 2024-04-06 ┆ 2024-04-07 ┆ 2024-04-08 ┆ 15    │
│ 2024-04-07 ┆ 2024-04-09 ┆ 2024-04-10 ┆ -20   │
│ 2024-04-10 ┆ 2024-04-10 ┆ 2024-05-15 ┆ 5     │
│ 2024-04-11 ┆ 2024-04-15 ┆ 2024-04-21 ┆ 30    │
└────────────┴────────────┴────────────┴───────┘

What I would like to do is create a dataframe that, for each 'Date1', contains the sum of 'Value' over all rows (not just the current one) where 'Date2' <= 'Date1' <= 'Date3'. For example, when 'Date1' = '2024-04-10' the sum should be -15: the first two rows are excluded because their 'Date3' values (2024-04-07 and 2024-04-08) are earlier than 2024-04-10, the last row is excluded because its 'Date2' (2024-04-15) is later than 2024-04-10, and the two remaining rows give -20 + 5 = -15.
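To pin down the rule, here is a brute-force sketch in plain Python (no polars). It assumes the example rows are copied into hypothetical `rows` and `date1` lists; every Date1 value is checked against every row's inclusive [Date2, Date3] window.

```python
from datetime import date

# Example rows copied from the table above as (Date2, Date3, Value).
rows = [
    (date(2024, 4, 4), date(2024, 4, 7), 10),
    (date(2024, 4, 7), date(2024, 4, 8), 15),
    (date(2024, 4, 9), date(2024, 4, 10), -20),
    (date(2024, 4, 10), date(2024, 5, 15), 5),
    (date(2024, 4, 15), date(2024, 4, 21), 30),
]
date1 = [date(2024, 4, 1), date(2024, 4, 6), date(2024, 4, 7), date(2024, 4, 10), date(2024, 4, 11)]


def conditional_sum(x, rows):
    """Sum Value over every row whose [Date2, Date3] window contains x (inclusive)."""
    return sum(v for d2, d3, v in rows if d2 <= x <= d3)


sums = [conditional_sum(x, rows) for x in date1]
print(sums)  # [0, 10, 25, -15, 5]
```

This is O(n²), so it is only meant as a reference to check faster versions against.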

I tried this:

(
    testDf
    .group_by("Date1")
    .agg(
        pl.col("Value")
        .filter((pl.col("Date1") >= pl.col("Date2")) & (pl.col("Date1") <= pl.col("Date3")))
        .sum()
    )
)
shape: (5, 2)
┌────────────┬───────┐
│ Date1      ┆ Value │
│ ---        ┆ ---   │
│ date       ┆ i64   │
╞════════════╪═══════╡
│ 2024-04-11 ┆ 0     │
│ 2024-04-06 ┆ 0     │
│ 2024-04-07 ┆ 0     │
│ 2024-04-10 ┆ 5     │
│ 2024-04-01 ┆ 0     │
└────────────┴───────┘

But my desired result is this:

shape: (5, 2)
┌────────────┬─────┐
│ Date1      ┆ Sum │
│ ---        ┆ --- │
│ date       ┆ i64 │
╞════════════╪═════╡
│ 2024-04-01 ┆ 0   │
│ 2024-04-06 ┆ 10  │
│ 2024-04-07 ┆ 25  │
│ 2024-04-10 ┆ -15 │
│ 2024-04-11 ┆ 5   │
└────────────┴─────┘
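The "cumulative" part of the title can be made literal. Assuming Date2 <= Date3 in every row (true for this data), the rows whose window contains x are the rows with Date2 <= x minus the rows with Date3 < x, so Sum(x) is the difference of two prefix sums. A plain-Python sketch using `bisect` (the names `rows`, `date1` and `window_sum` are illustrative, not from the question):

```python
from bisect import bisect_left, bisect_right
from datetime import date
from itertools import accumulate

# Example rows as (Date2, Date3, Value); assumes Date2 <= Date3 in every row.
rows = [
    (date(2024, 4, 4), date(2024, 4, 7), 10),
    (date(2024, 4, 7), date(2024, 4, 8), 15),
    (date(2024, 4, 9), date(2024, 4, 10), -20),
    (date(2024, 4, 10), date(2024, 5, 15), 5),
    (date(2024, 4, 15), date(2024, 4, 21), 30),
]
date1 = [date(2024, 4, 1), date(2024, 4, 6), date(2024, 4, 7), date(2024, 4, 10), date(2024, 4, 11)]

# Prefix sums of Value with rows sorted by Date2 and, separately, by Date3.
by_d2 = sorted(rows, key=lambda r: r[0])
by_d3 = sorted(rows, key=lambda r: r[1])
d2_keys = [r[0] for r in by_d2]
d3_keys = [r[1] for r in by_d3]
d2_prefix = [0, *accumulate(r[2] for r in by_d2)]
d3_prefix = [0, *accumulate(r[2] for r in by_d3)]


def window_sum(x):
    started = d2_prefix[bisect_right(d2_keys, x)]  # sum of Value where Date2 <= x
    ended = d3_prefix[bisect_left(d3_keys, x)]     # sum of Value where Date3 < x
    return started - ended


sums = [window_sum(x) for x in date1]
print(sums)  # [0, 10, 25, -15, 5]
```

After the two sorts, each lookup is O(log n). If some row had Date2 > Date3, the subtraction would no longer be valid.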

Solution

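  • I'll need to think about it a bit more to determine whether a solution using only polars' native expression API is possible. In the meantime, here is a preliminary solution based on pl.Expr.map_elements, which the polars documentation discourages because it calls a Python function for every element (here the lambda also re-filters the whole frame once per row).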

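    (
        testDf
        .with_columns(
            pl.col("Date1")
            .map_elements(
                # For each Date1 value x, sum Value over all rows of testDf
                # whose [Date2, Date3] window contains x (both ends inclusive).
                lambda x: (
                    testDf
                    .filter(
                        pl.col("Date2") <= x,
                        pl.col("Date3") >= x,
                    )
                    .get_column("Value")
                    .sum()
                ),
                return_dtype=pl.Int64,
            )
            .alias("Sum")
        )
    )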
    
    shape: (5, 5)
    ┌────────────┬────────────┬────────────┬───────┬─────┐
    │ Date1      ┆ Date2      ┆ Date3      ┆ Value ┆ Sum │
    │ ---        ┆ ---        ┆ ---        ┆ ---   ┆ --- │
    │ date       ┆ date       ┆ date       ┆ i64   ┆ i64 │
    ╞════════════╪════════════╪════════════╪═══════╪═════╡
    │ 2024-04-01 ┆ 2024-04-04 ┆ 2024-04-07 ┆ 10    ┆ 0   │
    │ 2024-04-06 ┆ 2024-04-07 ┆ 2024-04-08 ┆ 15    ┆ 10  │
    │ 2024-04-07 ┆ 2024-04-09 ┆ 2024-04-10 ┆ -20   ┆ 25  │
    │ 2024-04-10 ┆ 2024-04-10 ┆ 2024-05-15 ┆ 5     ┆ -15 │
    │ 2024-04-11 ┆ 2024-04-15 ┆ 2024-04-21 ┆ 30    ┆ 5   │
    └────────────┴────────────┴────────────┴───────┴─────┘