
Count consecutive True (or 1) values in a Boolean (or numeric) column with Polars?


I am hoping to count consecutive values in a column, preferably using Polars expressions.

import polars as pl

df = pl.DataFrame(
    {"values": [True, True, True, False, False, True, False, False, True, True]}
)

With the example data frame above, I would like to count the number of consecutive True values.

Below is the desired output, produced with R's data.table package.

library(data.table)
dt <- data.table(value = c(T,T,T,F,F,T,F,F,T,T))
dt[, value2 := fifelse((1:.N) == .N & value == 1, .N, NA_integer_), by = rleid(value)]
dt
value value2
 TRUE     NA
 TRUE     NA
 TRUE      3
FALSE     NA
FALSE     NA
 TRUE      1
FALSE     NA
FALSE     NA
 TRUE     NA
 TRUE      2

Any ideas how this could be done efficiently using Polars?

[EDIT with a new approach]

I got it working with the code below (it computes the length of the longest run of True values), but I'm hoping there is a more efficient way. Does anyone know the default field names of the struct returned by value_counts?

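(
    df.lazy()
    .with_row_index("row_nr")  # called with_row_count() in older Polars
    .with_columns(
        # Label each row with the index of the most recent False row, so every
        # run of Trues shares one id; a leading run of Trues gets id 0.
        pl.when(~pl.col("values"))
        .then(pl.col("row_nr"))
        .fill_null(strategy="forward")
        .fill_null(0)
        .alias("id_consecutive_Trues")
    )
    .select(
        pl.col("id_consecutive_Trues")
        .filter(pl.col("values"))  # count only the True rows of each run
        .value_counts(sort=True)
        # rename the struct fields explicitly instead of relying on defaults
        .struct.rename_fields(["value", "count"])
        .struct.field("count")
        .max()
        .alias("max_consecutive_true_values")
    )
    .collect()
)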

Solution

  • One possible definition of the problem is:

    • On the last row of each group of consecutive True values, return the group's length.
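    df.with_columns(
       # only True rows that are the last row of their run get a value
       pl.when(pl.col("values") & pl.col("values").is_last_distinct())
         .then(pl.len())                    # number of rows in the run
         .over(pl.col("values").rle_id())   # evaluate per run
         .alias("count")
    )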
    
    shape: (10, 2)
    ┌────────┬───────┐
    │ values ┆ count │
    │ ---    ┆ ---   │
    │ bool   ┆ u32   │
    ╞════════╪═══════╡
    │ true   ┆ null  │
    │ true   ┆ null  │
    │ true   ┆ 3     │
    │ false  ┆ null  │
    │ false  ┆ null  │
    │ true   ┆ 1     │
    │ false  ┆ null  │
    │ false  ┆ null  │
    │ true   ┆ null  │
    │ true   ┆ 2     │
    └────────┴───────┘
    

    .rle_id() assigns a "group id" to each run of consecutive identical values; the id increases by one every time the value changes.

    df.with_columns(group = pl.col("values").rle_id())
    
    shape: (10, 2)
    ┌────────┬───────┐
    │ values ┆ group │
    │ ---    ┆ ---   │
    │ bool   ┆ u32   │
    ╞════════╪═══════╡
    │ true   ┆ 0     │
    │ true   ┆ 0     │
    │ true   ┆ 0     │
    │ false  ┆ 1     │
    │ false  ┆ 1     │
    │ true   ┆ 2     │
    │ false  ┆ 3     │
    │ false  ┆ 3     │
    │ true   ┆ 4     │
    │ true   ┆ 4     │
    └────────┴───────┘
    

    Inside each .over() group every row has the same value, so .is_last_distinct() is True only on the group's last row.

    pl.len() with .over() gives us the number of rows in the group.