
Count same consecutive numbers in list column in polars dataframe


I have a pl.DataFrame with a column of integer lists. I need to assert that, within each list, no integer appears more than two times in a row.

For instance, the list [1,1,0,-1,1] would be OK, as the number 1 shows up at most two times in a row (the first two elements, followed by a zero).

This list should lead to a failed assertion: [1,1,1,0,-1]. Here the number 1 shows up three times in a row.

Here's a toy example, where row2 should lead to a failed assertion.

import polars as pl

row1 = [0, 1, -1, -1, 1, 1, -1, 0]
row2 = [1, -1, -1, -1, 0, 0, 1, -1]
df = pl.DataFrame({"list": [row1, row2]})
print(f"row1: {row1}")
print(f"row2: {row2}")
print(df)


row1: [0, 1, -1, -1, 1, 1, -1, 0]
row2: [1, -1, -1, -1, 0, 0, 1, -1]
shape: (2, 1)
┌───────────────┐
│ list          │
│ ---           │
│ list[i64]     │
╞═══════════════╡
│ [0, 1, … 0]   │
│ [1, -1, … -1] │
└───────────────┘

Solution

  • The following could be used.

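    1. Perform run-length encoding of the list using pl.Expr.rle. This produces one struct per run of equal consecutive values. Each struct contains the run's value and the corresponding run length (the len field used below). The same value can appear in several structs if it occurs in several separate runs.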

    2. Check whether the maximum run length in the list is at most 2.

    3. pl.Expr.list.eval always returns a list, here one holding a single boolean. Select that first (and only) element with pl.Expr.list.first so that the resulting column has type bool.

    df.with_columns(
        ok=pl.col("list").list.eval(
            pl.element().rle().struct.field("len").max() <= 2
        ).list.first()
    )
    
    shape: (2, 2)
    ┌──────────────────────────────┬───────┐
    │ list                         ┆ ok    │
    │ ---                          ┆ ---   │
    │ list[i64]                    ┆ bool  │
    ╞══════════════════════════════╪═══════╡
    │ [0, 1, -1, -1, 1, 1, -1, 0]  ┆ true  │
    │ [1, -1, -1, -1, 0, 0, 1, -1] ┆ false │
    └──────────────────────────────┴───────┘
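
As a cross-check of the logic in steps 1 and 2, here is a minimal plain-Python sketch that does not use polars. The helper name max_run_length is made up for this illustration, and itertools.groupby stands in for the run-length encoding; it is meant for verifying expectations on small inputs, not as a replacement for the expression above.

```python
from itertools import groupby


def max_run_length(values):
    """Return the length of the longest run of equal consecutive values.

    Hypothetical helper for illustration; returns 0 for an empty list.
    """
    # groupby yields one group per run of equal adjacent values,
    # so counting each group's items gives the run lengths.
    return max((sum(1 for _ in run) for _, run in groupby(values)), default=0)


row1 = [0, 1, -1, -1, 1, 1, -1, 0]
row2 = [1, -1, -1, -1, 0, 0, 1, -1]

print(max_run_length(row1) <= 2)  # True: the longest run has length 2
print(max_run_length(row2) <= 2)  # False: -1 appears three times in a row
```

The two printed values match the ok column in the table above.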
    

    Alternatively, pl.Expr.list.eval can be avoided by exploding the list. A window function (pl.Expr.over) is then needed to ensure the maximum is computed separately for each list. Here, pl.int_range(pl.len()) assigns every row its own index, so each list forms its own group.

    max_run_length = pl.col("list").explode().rle().struct.field("len").max().over(pl.int_range(pl.len()))
    df.with_columns(passed=max_run_length <= 2)
    

    The result is the same, except that the new column is named passed instead of ok.
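
To see why the maximum has to be computed per list once the lists are exploded, consider this plain-Python sketch. It is an illustration only and does not model polars internals; the data and the helper name run_lengths are hypothetical. If the values of all lists are simply concatenated, a run can continue across a list boundary, and a single maximum is produced instead of one per row.

```python
from itertools import groupby


def run_lengths(values):
    """Hypothetical helper: lengths of the runs of equal consecutive values."""
    return [sum(1 for _ in run) for _, run in groupby(values)]


# Made-up data: each list on its own has a longest run of 2.
lists = [[0, 1, 1], [1, 0, 0]]

# Computing the maximum separately per list gives one result per row.
per_list_max = [max(run_lengths(lst)) for lst in lists]

# Concatenating first lets the run of 1s span the boundary between lists.
flattened = [v for lst in lists for v in lst]
global_max = max(run_lengths(flattened))

print(per_list_max)  # [2, 2]
print(global_max)    # 3
```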