Python Polars: Number of rows since last value >0


Given a Polars DataFrame column such as

import polars as pl

df = pl.DataFrame({"a": [0, 29, 28, 4, 0, 0, 13, 0]})

how can I get a new column dist that counts the rows since the last value > 0, like this:

shape: (8, 2)
┌─────┬──────┐
│ a   ┆ dist │
│ --- ┆ ---  │
│ i64 ┆ i64  │
╞═════╪══════╡
│ 0   ┆ 1    │
│ 29  ┆ 0    │
│ 28  ┆ 0    │
│ 4   ┆ 0    │
│ 0   ┆ 1    │
│ 0   ┆ 2    │
│ 13  ┆ 0    │
│ 0   ┆ 1    │
└─────┴──────┘
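For reference, the expected dist values can be stated as a plain-Python loop. This is only a sketch of the semantics as read from the table above (the counter resets to 0 on a value > 0 and otherwise increases by 1), not a Polars solution; the helper name rows_since_positive is hypothetical.

```python
def rows_since_positive(values):
    """Hypothetical helper: rows since the last value > 0, counting the current row."""
    out = []
    run = 0
    for v in values:
        if v > 0:
            run = 0  # a positive value resets the counter
        else:
            run += 1  # each further non-positive row extends the current run
        out.append(run)
    return out


print(rows_since_positive([0, 29, 28, 4, 0, 0, 13, 0]))
```

For non-negative data like the example, testing v > 0 is equivalent to testing v != 0.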

Ideally, the solution should also work with .over() for grouped values and, optionally, with an additional rolling window function such as rolling_mean().

I know of the corresponding question for pandas but couldn't manage to translate it to Polars.


Solution

  • Here's one way: use rle_id to label each run of zero and non-zero values, count the rows within each run with cum_count().over(...), and keep that count only for the zero runs with when/then (all other rows get 0):

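    import polars as pl
    
    df = pl.DataFrame({"a": [0, 29, 28, 4, 0, 0, 13, 0]})
    
    df.with_columns(
        # rle_id() gives each run of equal ne(0) values its own id;
        # cum_count() then numbers the rows within each run from 1.
        dist=pl.when(pl.col('a') == 0)
        .then(pl.col('a').cum_count().over(pl.col('a').ne(0).rle_id()))
        .otherwise(0)  # rows with a non-zero value get 0
    )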
    
    shape: (8, 2)
    ┌─────┬──────┐
    │ a   ┆ dist │
    │ --- ┆ ---  │
    │ i64 ┆ u32  │
    ╞═════╪══════╡
    │ 0   ┆ 1    │
    │ 29  ┆ 0    │
    │ 28  ┆ 0    │
    │ 4   ┆ 0    │
    │ 0   ┆ 1    │
    │ 0   ┆ 2    │
    │ 13  ┆ 0    │
    │ 0   ┆ 1    │
    └─────┴──────┘