Search code examples
pythonpython-polars

Calculate Windowed Event Chains


Given a Polars DataFrame

# Sample input: nine rows for two users, with a boolean `event` flag per row.
data = pl.DataFrame(
    {
        "user_id": [1, 1, 1, 1, 1, 2, 2, 2, 2],
        "event": [False, True, True, False, True, True, True, False, False],
    }
)

I wish to calculate a column event_chain that counts a user's streak of events: an event extends the streak when the user also had an event in any of the previous 4 rows. Every time a new event happens while the user already has an active streak, the streak counter is incremented; the counter should then be reset to zero if the user has no further event within the next 4 rows.

user_id event event_chain reason for value
1 False 0 no events yet
1 True 0 No events in last 4 rows (not inclusive of current row)
1 True 1 event this row, and 1 event in last 4 rows
1 False 1 Does not reset to 0 as there is an event within the next 4 rows
1 True 2 event this row and event last 4 rows, increment the streak
2 True 0 No previous events
2 True 1 Event this row and in last 4 rows for user
2 False 0 No event this row, and no events in next 4 rows for user, resets to 0
2 False 0

I have working code as follows to do this, but I think there should be a cleaner way to do it

        (
            data.with_columns(
                # Row offset from the user's most recent event row (0 on an
                # event row, null before the user's first event).
                rows_since_last_event=pl.int_range(pl.len()).over("user_id")
                - pl.when("event").then(pl.int_range(pl.len())).forward_fill()
                .over("user_id"),
                # Row offset to the user's next event row (0 on an event row,
                # null after the user's last event).
                rows_till_next_event=pl.when("event").then(pl.int_range(pl.len()))
                .backward_fill().over("user_id") - pl.int_range(pl.len()).over("user_id"),
            )
            .with_columns(
                # 1 when an earlier row inside the rolling window has an event.
                # The window of 4 includes the current row, so the current
                # row's own value is subtracted before testing for > 0.
                chain_event=pl.when(
                    pl.col("event")
                    .fill_null(0)
                    .rolling_sum(window_size=4, min_periods=1)
                    .over("user_id")
                    - pl.col("event").fill_null(0)
                    > 0
                )
                .then(1)
                .otherwise(0)
            )
            .with_columns(
                # Flag the rows where a chain starts or ends; multiple
                # predicates passed to `when` are combined with AND.
                chain_event_change=pl.when(
                    # Chain start: flag flips 0 -> 1 with no recent event.
                    pl.col("chain_event").eq(1),
                    pl.col("chain_event").shift().eq(0),
                    pl.col("rows_since_last_event").fill_null(5) > 3,
                )
                .then(1)
                .when(
                    # Chain end: flag flips 1 -> 0 with no upcoming event.
                    pl.col("chain_event").eq(0),
                    pl.col("chain_event").shift().eq(1),
                    pl.col("rows_till_next_event").fill_null(5) > 3,
                )
                .then(1)
                .otherwise(0)
            )
            .with_columns(
                # Running count of change points per user; rows sharing a
                # value belong to the same segment.
                chain_event_identifier=pl.col("chain_event_change")
                .cum_sum()
                .over("user_id")
            )
            .with_columns(
                # Streak counter: cumulative sum of the chain flag within
                # each (user, segment) group.
                event_chain=pl.col("chain_event")
                .cum_sum()
                .over("user_id", "chain_event_identifier")
            )
        )

Solution

  • Updated version: I looked at @jqurious's answer, and I think it can be made even more concise.

    • .sum_horizontal() to precalculate the counter while checking the previous N rows. We only need the sum for the previous rows; for the next rows we just need to know whether any event exists, so max is enough.
    • Also note that we use a window of size 5 (including the current row) instead, so we don't need a special case for the 'starting' event.
    (
        data
        .with_columns(
            # Number of events in the 5-row window ending at the current row
            # (shift(0) is the current row itself), computed per user.
            prev =
               pl.sum_horizontal(pl.col.event.shift(i) for i in range(5))
                 .over('user_id'),
            # Whether any of the next 4 rows of the same user has an event;
            # a null result is replaced with False.
            next =
               pl.max_horizontal(pl.col.event.shift(-i) for i in range(1,5))
                 .over('user_id').fill_null(False)
        ).with_columns(
            pl
            # No event now and none in the next 4 rows: the counter resets.
            .when(event = False, next = False).then(0)
            # No event now and none in the window: no chain is active.
            .when(event = False, prev = 0).then(0)
            # Otherwise the window count minus one (the first event counts as 0).
            .otherwise(pl.col.prev - 1)
            .alias('chain_event')
            # or even shorter but a bit more cryptic
            # NOTE(review): for a non-event row inside an active chain this
            # variant appears to yield `prev` rather than `prev - 1` - verify
            # against the expected output before using it.
            # pl
            # .when(event = False, next = False).then(0)
            # .otherwise(pl.col.prev - pl.col.event)
            # .alias('chain_event')
        )
    )
    
    ┌─────────┬───────┬──────┬───────┬─────────────┐
    │ user_id ┆ event ┆ prev ┆ next  ┆ chain_event │
    │ ---     ┆ ---   ┆ ---  ┆ ---   ┆ ---         │
    │ i64     ┆ bool  ┆ u32  ┆ bool  ┆ i64         │
    ╞═════════╪═══════╪══════╪═══════╪═════════════╡
    │ 1       ┆ false ┆ 0    ┆ true  ┆ 0           │
    │ 1       ┆ true  ┆ 1    ┆ true  ┆ 0           │
    │ 1       ┆ true  ┆ 2    ┆ true  ┆ 1           │
    │ 1       ┆ false ┆ 2    ┆ true  ┆ 1           │
    │ 1       ┆ true  ┆ 3    ┆ false ┆ 2           │
    │ 2       ┆ true  ┆ 1    ┆ true  ┆ 0           │
    │ 2       ┆ true  ┆ 2    ┆ false ┆ 1           │
    │ 2       ┆ false ┆ 2    ┆ false ┆ 0           │
    │ 2       ┆ false ┆ 2    ┆ false ┆ 0           │
    └─────────┴───────┴──────┴───────┴─────────────┘
    

    previous version

    Basically, the most important part here is that we treat a row as being within a chain if either of the following conditions is met:

    • There's an event in the current row.
    • There's an event within the previous 4 rows (otherwise we've already restarted the counter) AND there's an event within the next 4 rows (otherwise we're going to reset the counter).
    (
        data.with_columns(
            # Largest `event` value over the 4 preceding rows of the same
            # user; a null result is replaced with False.
            max_lag=pl.max_horizontal(
                pl.col("event").shift(offset).over("user_id")
                for offset in range(1, 5)
            ).fill_null(False),
            # Same over the 4 following rows of the same user.
            max_lead=pl.max_horizontal(
                pl.col("event").shift(-offset).over("user_id")
                for offset in range(1, 5)
            ).fill_null(False),
        )
        .with_columns(
            # A row is inside a chain when it has an event itself, or when it
            # has an event both behind it and ahead of it.
            event_chain=pl.col("event") | (pl.col("max_lag") & pl.col("max_lead")),
        )
        .select(
            "user_id",
            "event",
            # Within each run of consecutive equal `event_chain` values per
            # user, count events cumulatively and subtract one; rows outside
            # a chain get 0.
            event_chain=pl.when(pl.col("event_chain"))
            .then(
                pl.col("event")
                .cum_sum()
                .over("user_id", pl.col("event_chain").rle_id().over("user_id"))
                - 1
            )
            .otherwise(0),
        )
    )
    
    ┌─────────┬───────┬─────────────┐
    │ user_id ┆ event ┆ event_chain │
    │ ---     ┆ ---   ┆ ---         │
    │ i64     ┆ bool  ┆ i64         │
    ╞═════════╪═══════╪═════════════╡
    │ 1       ┆ false ┆ 0           │
    │ 1       ┆ true  ┆ 0           │
    │ 1       ┆ true  ┆ 1           │
    │ 1       ┆ false ┆ 1           │
    │ 1       ┆ true  ┆ 2           │
    │ 2       ┆ true  ┆ 0           │
    │ 2       ┆ true  ┆ 1           │
    │ 2       ┆ false ┆ 0           │
    │ 2       ┆ false ┆ 0           │
    └─────────┴───────┴─────────────┘
    

    Alternatively

    • .rolling_max() to calculate whether there's an event within the previous 4 rows
    • the same with .reverse() to calculate whether there's an event within the next 4 rows
    (
        data.with_columns(
            # Rolling max (window 4) over the previous rows of the same user:
            # the column is shifted by one so the current row is excluded.
            max_lag=pl.col("event")
            .cast(pl.Int32)
            .shift(1)
            .rolling_max(4, min_periods=0)
            .over("user_id")
            .fill_null(0),
            # Same computation on the reversed column, reversed back, so the
            # window covers the following rows instead.
            max_lead=pl.col("event")
            .reverse()
            .cast(pl.Int32)
            .shift(1)
            .rolling_max(4, min_periods=0)
            .reverse()
            .over("user_id")
            .fill_null(0),
        )
        .with_columns(
            # In a chain: event behind AND event ahead, or an event on this row.
            event_chain=(pl.col("max_lag").eq(1) & pl.col("max_lead").eq(1))
            | pl.col("event"),
        )
        .select(
            "user_id",
            "event",
            # Cumulative event count minus one within each run of consecutive
            # equal `event_chain` values per user; 0 outside a chain.
            event_chain=pl.when(pl.col("event_chain"))
            .then(
                pl.col("event")
                .cum_sum()
                .over("user_id", pl.col("event_chain").rle_id().over("user_id"))
                - 1
            )
            .otherwise(0),
        )
    )