Python Polars: Calculate rolling mode over multiple columns


I have a polars.DataFrame like:

import polars as pl

data = pl.DataFrame({
    "col1": [3, 2, 4, 7, 1, 10, 7],
    "col2": [3, 4, None, 1, None, 1, 9],
    "col3": [3, 1, None, None, None, None, 4],
    "col4": [None, 5, None, None, None, None, None],
    "col5": [None, None, None, None, None, None, None],
})

┌──────┬──────┬──────┬──────┬──────┐
│ col1 ┆ col2 ┆ col3 ┆ col4 ┆ col5 │
│ ---  ┆ ---  ┆ ---  ┆ ---  ┆ ---  │
│ i64  ┆ i64  ┆ i64  ┆ i64  ┆ f32  │
╞══════╪══════╪══════╪══════╪══════╡
│ 3    ┆ 3    ┆ 3    ┆ null ┆ null │
│ 2    ┆ 4    ┆ 1    ┆ 5    ┆ null │
│ 4    ┆ null ┆ null ┆ null ┆ null │
│ 7    ┆ 1    ┆ null ┆ null ┆ null │
│ 1    ┆ null ┆ null ┆ null ┆ null │
│ 10   ┆ 1    ┆ null ┆ null ┆ null │
│ 7    ┆ 9    ┆ 4    ┆ null ┆ null │
└──────┴──────┴──────┴──────┴──────┘

I want to create a new column containing the rolling mode, computed not from a single column's values within the window but from the values of all columns within the window. Nulls should be dropped and should never appear as a mode value in the result.

edit:

I made some changes to the example data. To clarify: assuming something like polars.rolling_apply(<function>, window_size=2, min_periods=1, center=False), I would expect the following result:

┌──────┐
│ res  │
│ ---  │
│ i64  │
╞══════╡
│ 3    │
│ 3    │
│ 4    │
│ None │ <- all values different
│ 1    │
│ 1    │
│ None │ <- all values different
└──────┘

If a window has no unique mode, None as the result would be fine. Only the missing values in the original polars.DataFrame should be ignored.


Solution

  • .rolling() can be used to aggregate over the windows.

    Using pl.concat_list() inside .agg() gives us a nested list for each window, e.g.

    • [[col1, col2, ...], [col1, col2, ...]]

    We can then flatten it, drop the nulls, and calculate the mode.

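    (data.with_row_index()
         # window of 2 rows: the current row plus the previous one
         .rolling(index_column="index", period="2i")
         .agg(
            # pool every column of every row in the window, drop nulls, take the mode(s)
            pl.concat_list(pl.exclude("index")).flatten().drop_nulls().mode()
              .alias("mode")
         )
    #    .with_columns(
    #       pl.when(pl.col("mode").list.len() == 1)
    #         .then(pl.col("mode").list.first())
    #    )
    )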
    
    shape: (7, 2)
    ┌───────┬──────────────────┐
    │ index ┆ mode             │
    │ ---   ┆ ---              │
    │ u32   ┆ list[i64]        │
    ╞═══════╪══════════════════╡
    │ 0     ┆ [3]              │
    │ 1     ┆ [3]              │
    │ 2     ┆ [4]              │
    │ 3     ┆ [1, 7, 4]        │
    │ 4     ┆ [1]              │
    │ 5     ┆ [1]              │
    │ 6     ┆ [10, 7, 9, 1, 4] │
    └───────┴──────────────────┘
    

    The commented-out lines discard ties (a window with more than one mode becomes null) and unwrap the single remaining value from the list:

    shape: (7, 2)
    ┌───────┬──────┐
    │ index ┆ mode │
    │ ---   ┆ ---  │
    │ u32   ┆ i64  │
    ╞═══════╪══════╡
    │ 0     ┆ 3    │
    │ 1     ┆ 3    │
    │ 2     ┆ 4    │
    │ 3     ┆ null │
    │ 4     ┆ 1    │
    │ 5     ┆ 1    │
    │ 6     ┆ null │
    └───────┴──────┘
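
As a cross-check of the result above, the same logic can be sketched in plain Python without Polars. This is only an illustrative reference under the question's assumptions (a trailing window of 2 rows, at least 1 row required, ties treated as "no mode"); the helper name rolling_mode_across_columns is hypothetical and not part of any library.

```python
from collections import Counter

# Same data as in the question, stored row by row (None marks a null).
rows = [
    (3, 3, 3, None, None),
    (2, 4, 1, 5, None),
    (4, None, None, None, None),
    (7, 1, None, None, None),
    (1, None, None, None, None),
    (10, 1, None, None, None),
    (7, 9, 4, None, None),
]


def rolling_mode_across_columns(rows, window_size=2):
    """Hypothetical helper: unique mode of all non-null values in a trailing window of rows."""
    result = []
    for i in range(len(rows)):
        # Trailing window: the current row plus up to window_size - 1 previous rows.
        window = rows[max(0, i - window_size + 1): i + 1]
        # Pool the values of every column and drop the nulls.
        values = [v for row in window for v in row if v is not None]
        counts = Counter(values)
        if not counts:
            # Window contained only nulls.
            result.append(None)
            continue
        top = max(counts.values())
        modes = [v for v, c in counts.items() if c == top]
        # Keep the mode only when it is unique; ties become None.
        result.append(modes[0] if len(modes) == 1 else None)
    return result


print(rolling_mode_across_columns(rows))
# [3, 3, 4, None, 1, 1, None]
```

The printed list matches the final "mode" column, with None standing in for null. A pure-Python loop like this will be much slower than the Polars expression on large frames, so it is meant for verifying the semantics rather than for production use.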