I am hoping to count consecutive values in a column, preferably using Polars expressions.
import polars as pl
df = pl.DataFrame(
{"values": [True,True,True,False,False,True,False,False,True,True]}
)
With the example data frame above, I would like to count the number of consecutive True values.
Below is example output using R's data.table package:
library(data.table)
dt <- data.table(value = c(T,T,T,F,F,T,F,F,T,T))
dt[, value2 := fifelse((1:.N) == .N & value == 1, .N, NA_integer_), by = rleid(value)]
dt
value | value2 |
---|---|
TRUE | NA |
TRUE | NA |
TRUE | 3 |
FALSE | NA |
FALSE | NA |
TRUE | 1 |
FALSE | NA |
FALSE | NA |
TRUE | NA |
TRUE | 2 |
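To pin down exactly what the data.table code computes, here is a plain-Python sketch. It uses only the standard library and is not Polars code. `itertools.groupby` splits the column into runs of equal values, much as `rleid()` does, and each True run's length is written on that run's last row. The function name `last_row_run_lengths` is hypothetical.

```python
from itertools import groupby


def last_row_run_lengths(values):
    """For each run of consecutive equal values, put the run length on the
    run's last row if the run is True; every other row gets None."""
    out = []
    for value, run in groupby(values):
        n = len(list(run))
        # All rows of the run except the last one get None.
        out.extend([None] * (n - 1))
        # The last row gets the run length, but only for True runs.
        out.append(n if value else None)
    return out


values = [True, True, True, False, False, True, False, False, True, True]
print(last_row_run_lengths(values))
# [None, None, 3, None, None, 1, None, None, None, 2]
```

The printed list matches the `value2` column of the R output row for row.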
Any ideas how this could be done efficiently using Polars?
[EDIT with a new approach]
I got it working with the code below, but I'm hoping there is a more efficient way. Does anyone know the default struct field names returned by value_counts?
(
df.lazy()
.with_row_count()
.with_column(
pl.when(pl.col("values") == False).then(
pl.col("row_nr")
).fill_null(
strategy = "forward"
).alias("id_consecutive_Trues")
)
.with_column(
pl.col("id_consecutive_Trues").value_counts(sort = True)
)
.with_column(
(
pl.col("id_consecutive_Trues").arr.eval(
pl.element().struct().rename_fields(["value", "count"]).struct.field("count")
).arr.max()
- pl.lit(1)
).alias("max_consecutive_true_values")
)
.collect()
)
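The idea behind this forward-fill approach can be sketched in plain Python (standard library only; the helper name `max_consecutive_true` is hypothetical). One detail seems worth checking: a run of True values at the very start of the column has no False row in its group, so subtracting 1 from every group size would undercount it. On the example data the largest groups have 3 rows each (the leading group and the group starting at row 7), so "largest count minus 1" gives 2, while the longest True run is 3. The sketch therefore subtracts 1 only from groups that actually begin with a False row.

```python
from collections import Counter


def max_consecutive_true(values):
    # Step 1: each False row gets its own row number as an id; True rows
    # inherit the id of the most recent False row (a forward fill).
    # Leading True rows have no earlier False row, so they keep None.
    ids = []
    last_false = None
    for row_nr, v in enumerate(values):
        if not v:
            last_false = row_nr
        ids.append(last_false)

    # Step 2: count rows per id, similar to value_counts on the id column.
    group_sizes = Counter(ids)

    # Step 3: a group with a real id starts with exactly one False row, so its
    # True-row count is size - 1; the leading None group has no False row.
    true_counts = [
        size - 1 if gid is not None else size
        for gid, size in group_sizes.items()
    ]
    return max(true_counts, default=0)


values = [True, True, True, False, False, True, False, False, True, True]
print(max_consecutive_true(values))  # 3
```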
One possible definition of the problem is: for each `true` group, give me the group length (on the group's last row).
df.with_columns(
pl.when(pl.col("values") & pl.col("values").is_last_distinct())
.then(pl.len())
.over(pl.col("values").rle_id())
.alias("count")  # pl.len() on its own would name the column "len"
)
shape: (10, 2)
┌────────┬───────┐
│ values ┆ count │
│ --- ┆ --- │
│ bool ┆ u32 │
╞════════╪═══════╡
│ true ┆ null │
│ true ┆ null │
│ true ┆ 3 │
│ false ┆ null │
│ false ┆ null │
│ true ┆ 1 │
│ false ┆ null │
│ false ┆ null │
│ true ┆ null │
│ true ┆ 2 │
└────────┴───────┘
`.rle_id()` gives us "group ids" for the consecutive values.
df.with_columns(group = pl.col("values").rle_id())
shape: (10, 2)
┌────────┬───────┐
│ values ┆ group │
│ --- ┆ --- │
│ bool ┆ u32 │
╞════════╪═══════╡
│ true ┆ 0 │
│ true ┆ 0 │
│ true ┆ 0 │
│ false ┆ 1 │
│ false ┆ 1 │
│ true ┆ 2 │
│ false ┆ 3 │
│ false ┆ 3 │
│ true ┆ 4 │
│ true ┆ 4 │
└────────┴───────┘
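For intuition, the numbering in the `group` column can be reproduced with a few lines of plain Python. This is an illustrative re-implementation of the idea, not Polars' actual code, and the name `rle_ids` is hypothetical: the id starts at 0 and goes up by 1 whenever a value differs from the one in the previous row.

```python
def rle_ids(values):
    """Run-length ids: start at 0, increment whenever the value changes."""
    ids = []
    current = 0
    for i, v in enumerate(values):
        # A new run starts whenever this row differs from the previous row.
        if i > 0 and v != values[i - 1]:
            current += 1
        ids.append(current)
    return ids


values = [True, True, True, False, False, True, False, False, True, True]
print(rle_ids(values))  # [0, 0, 0, 1, 1, 2, 3, 3, 4, 4]
```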
`.is_last_distinct()` with `.over()` allows us to detect the last row of each group.
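This works because every row inside one `rle_id()` group holds the same value, so the "last distinct" occurrence of that value within the group is simply the group's final row. The plain-Python sketch below uses a shortcut that relies on the same fact: since run-length groups are contiguous, a row is the last of its group when the next row has a different group id (or when there is no next row). The helper name `is_last_in_group` is hypothetical.

```python
def is_last_in_group(group_ids):
    """True on the last row of each contiguous run of equal group ids."""
    n = len(group_ids)
    return [i == n - 1 or group_ids[i + 1] != group_ids[i] for i in range(n)]


# Group ids from the rle_id() output above.
group_ids = [0, 0, 0, 1, 1, 2, 3, 3, 4, 4]
print(is_last_in_group(group_ids))
# [False, False, True, False, True, True, False, True, False, True]
```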
`pl.len()` with `.over()` gives us the number of rows in the group.
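Putting the pieces together, the following plain-Python sketch (an illustration, not how Polars evaluates the expression) broadcasts each group's size to its rows, as `pl.len().over(...)` does, then keeps that size only where the `when` condition holds: the value is True and the row is the last of its group. Other rows get `None`, which corresponds to `null` in the Polars output.

```python
from collections import Counter

values = [True, True, True, False, False, True, False, False, True, True]
# Group ids from the rle_id() output above.
group_ids = [0, 0, 0, 1, 1, 2, 3, 3, 4, 4]

# Like pl.len().over(group): every row gets the size of its group.
sizes = Counter(group_ids)
group_len = [sizes[g] for g in group_ids]

# Like the when/then: keep the length only on the last row of a True group.
n = len(values)
count = [
    group_len[i]
    if values[i] and (i == n - 1 or group_ids[i + 1] != group_ids[i])
    else None
    for i in range(n)
]
print(group_len)  # [3, 3, 3, 2, 2, 1, 2, 2, 2, 2]
print(count)      # [None, None, 3, None, None, 1, None, None, None, 2]
```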