Given a Polars DataFrame
# Sample input: one row per (user, step); `event` marks whether the user had
# an event on that row. Rows are assumed to already be in order per user.
data = pl.DataFrame(
    {
        "user_id": [1, 1, 1, 1, 1, 2, 2, 2, 2],
        "event": [False, True, True, False, True, True, True, False, False],
    }
)
I wish to calculate a column `event_chain` which counts a streak of events for each user: an event extends the streak when the user also had an event in any of the previous 4 rows. Every time a new event happens while the user already has an active streak, the streak counter is incremented; the counter is then reset to zero if the user has no further event within the next 4 rows.
user_id | event | event_chain | reason for value |
---|---|---|---|
1 | False | 0 | no events yet |
1 | True | 0 | No events in last 4 rows (not inclusive of current row) |
1 | True | 1 | event this row, and 1 event in last 4 rows |
1 | False | 1 | Does not reset to 0 as there is an event within the next 4 rows |
1 | True | 2 | event this row and event last 4 rows, increment the streak |
2 | True | 0 | No previous events |
2 | True | 1 | Event this row and in last 4 rows for user |
2 | False | 0 | No event this row, and no events in next 4 rows for user, resets to 0 |
2 | False | 0 |
I have working code as follows to do this, but I think there should be a cleaner way to do it
# Step-by-step computation of `event_chain` for each user.
# The whole chain is wrapped in parentheses so the leading-dot continuation
# lines parse as a single expression.
(
    data.with_columns(
        # Per-user row index minus the (forward-filled) index of the most
        # recent event row. Null until the user's first event.
        rows_since_last_event=pl.int_range(pl.len()).over("user_id")
        - pl.when("event")
        .then(pl.int_range(pl.len()))
        .forward_fill()
        .over("user_id"),
        # (Backward-filled) index of the next event row minus the per-user
        # row index. Null after the user's last event.
        # Both sides are partitioned by "user_id"; the frame has no
        # "athlete_id" column.
        rows_till_next_event=pl.when("event")
        .then(pl.int_range(pl.len()))
        .backward_fill()
        .over("user_id")
        - pl.int_range(pl.len()).over("user_id"),
    )
    .with_columns(
        # 1 when the 4-row rolling sum of `event` (window ends at the current
        # row) minus the current row's own `event` is positive, else 0.
        chain_event=pl.when(
            pl.col("event")
            .fill_null(0)
            .rolling_sum(window_size=4, min_periods=1)
            .over("user_id")
            - pl.col("event").fill_null(0)
            > 0
        )
        .then(1)
        .otherwise(0)
    )
    .with_columns(
        # 1 on rows where `chain_event` flips relative to the previous row and
        # the nearest event on the relevant side is more than 3 rows away
        # (a null distance is treated as 5); 0 otherwise.
        # NOTE(review): `.shift()` is not partitioned by user_id here, so the
        # first row of a user is compared with the previous user's last row
        # - confirm this is intended.
        chain_event_change=pl.when(
            pl.col("chain_event").eq(1),
            pl.col("chain_event").shift().eq(0),
            pl.col("rows_since_last_event").fill_null(5) > 3,
        )
        .then(1)
        .when(
            # Uses `chain_event`, the column created above; no
            # "congested_event" column exists in this frame.
            pl.col("chain_event").eq(0),
            pl.col("chain_event").shift().eq(1),
            pl.col("rows_till_next_event").fill_null(5) > 3,
        )
        .then(1)
        .otherwise(0)
    )
    .with_columns(
        # Running count of change markers per user; used as a group id below.
        chain_event_identifier=pl.col("chain_event_change")
        .cum_sum()
        .over("user_id")
    )
    .with_columns(
        # Running sum of `chain_event` within each (user, identifier) group.
        event_chain=pl.col("chain_event")
        .cum_sum()
        .over("user_id", "chain_event_identifier")
    )
)
Updated version: I looked at @jqurious's answer and I think it can be made even more concise:

- `.sum_horizontal()` to precalculate the counter while checking the previous N rows.
- `.max_horizontal()` to check the next N rows. We only need the sum for the previous rows; for the next rows we just need to know whether an event exists, so max is enough.

(
data
.with_columns(
chain_event =
pl.sum_horizontal(pl.col.event.shift(i) for i in range(5))
.over('user_id'),
next =
pl.max_horizontal(pl.col.event.shift(-i) for i in range(1,5))
.over('user_id').fill_null(False)
).with_columns(
pl
.when(event = False, next = False).then(0)
.when(event = False, chain_event = 0).then(0)
.otherwise(pl.col.chain_event - 1)
.alias('chain_event')
# or even shorter but a bit more cryptic
# pl
# .when(event = False, next = False).then(0)
# .otherwise(pl.col.chain_event - pl.col.event)
# .alias('chain_event')
)
)
┌─────────┬───────┬──────┬───────┬─────────────┐
│ user_id ┆ event ┆ prev ┆ next ┆ chain_event │
│ --- ┆ --- ┆ --- ┆ --- ┆ --- │
│ i64 ┆ bool ┆ u32 ┆ bool ┆ i64 │
╞═════════╪═══════╪══════╪═══════╪═════════════╡
│ 1 ┆ false ┆ 0 ┆ true ┆ 0 │
│ 1 ┆ true ┆ 1 ┆ true ┆ 0 │
│ 1 ┆ true ┆ 2 ┆ true ┆ 1 │
│ 1 ┆ false ┆ 2 ┆ true ┆ 1 │
│ 1 ┆ true ┆ 3 ┆ false ┆ 2 │
│ 2 ┆ true ┆ 1 ┆ true ┆ 0 │
│ 2 ┆ true ┆ 2 ┆ false ┆ 1 │
│ 2 ┆ false ┆ 2 ┆ false ┆ 0 │
│ 2 ┆ false ┆ 2 ┆ false ┆ 0 │
└─────────┴───────┴──────┴───────┴─────────────┘
Previous version:

- `.shift()` to get the 4 previous and 4 next rows.
- `.max_horizontal()` so we know if there's an event within these windows.
- `.rle_id()` to create continuous groups of events so we can restart the counter.
- `.cum_sum()` to increment counters.
- `.when().then().otherwise()` to only take into account groups with events.

Basically, the most important part here is that we treat a row as being within a chain if either one of these conditions is met:

- there is an event in the current row, or
- there is an event within the 4 previous rows and an event within the 4 next rows.
(
    data.with_columns(
        # Maximum of `event` over the 4 previous rows of the same user;
        # a null result is replaced with False.
        pl.max_horizontal(
            [pl.col("event").shift(k).over("user_id") for k in range(1, 5)]
        )
        .fill_null(False)
        .alias("max_lag"),
        # Same for the 4 following rows of the same user.
        pl.max_horizontal(
            [pl.col("event").shift(-k).over("user_id") for k in range(1, 5)]
        )
        .fill_null(False)
        .alias("max_lead"),
    )
    .with_columns(
        # A row is inside a chain when it has events on both sides, or is
        # itself an event.
        event_chain=(pl.col("max_lag") & pl.col("max_lead")) | pl.col("event")
    )
    .select(
        "user_id",
        "event",
        # Inside a chain: running count of events within the current run of
        # equal `event_chain` values (per user), minus 1. Outside a chain: 0.
        pl.when(pl.col("event_chain"))
        .then(
            pl.col("event")
            .cum_sum()
            .over("user_id", pl.col("event_chain").rle_id().over("user_id"))
            - 1
        )
        .otherwise(0)
        .alias("event_chain"),
    )
)
┌─────────┬───────┬─────────────┐
│ user_id ┆ event ┆ event_chain │
│ --- ┆ --- ┆ --- │
│ i64 ┆ bool ┆ i64 │
╞═════════╪═══════╪═════════════╡
│ 1 ┆ false ┆ 0 │
│ 1 ┆ true ┆ 0 │
│ 1 ┆ true ┆ 1 │
│ 1 ┆ false ┆ 1 │
│ 1 ┆ true ┆ 2 │
│ 2 ┆ true ┆ 0 │
│ 2 ┆ true ┆ 1 │
│ 2 ┆ false ┆ 0 │
│ 2 ┆ false ┆ 0 │
└─────────┴───────┴─────────────┘
Alternatively:

- `.rolling_max()` to calculate if there's an event within the previous 4 rows.
- `.reverse()` to calculate if there's an event within the next 4 rows.

(
data
.with_columns(
(pl.col('event').cast(pl.Int32).shift(1).rolling_max(4, min_periods=0)).over('user_id').fill_null(0).alias('max_lag'),
(pl.col('event').reverse().cast(pl.Int32).shift(1).rolling_max(4, min_periods=0).reverse()).over('user_id').fill_null(0).alias('max_lead')
).with_columns(
event_chain = ((pl.col("max_lag") == 1) & (pl.col("max_lead") == 1)) | pl.col('event')
).select(
pl.col('user_id','event'),
pl.when(pl.col('event_chain'))
.then(
pl.col('event').cum_sum().over('user_id', pl.col('event_chain').rle_id().over('user_id')) - 1
).otherwise(0)
.alias('event_chain')
)
)