How to filter polars dataframe by first maximum value while using over?


I am trying to filter a dataframe to find the first occurrence of the maximum value within each category. In my data there is no guarantee of a single unique maximum per category (several rows can share it), but I only need the first occurrence.

However, I can't find a way to restrict the max part of the filter to a single row. Currently I add a second filter on another column (usually a time-based one) and take its minimum value to break ties.

import polars as pl

df = pl.DataFrame(
    {
        "cat": [1, 1, 1, 2, 2, 2, 2, 3, 3, 3],
        "max_col": [12, 24, 36, 15, 50, 50, 45, 20, 40, 60],
        "other_col": [25, 50, 75, 125, 150, 175, 200, 225, 250, 275],
    }
)

df = df.filter(pl.col("max_col") == pl.col("max_col").max().over("cat")).filter(
    pl.col("other_col") == pl.col("other_col").min().over("cat")
)

shape: (3, 3)
┌─────┬─────────┬───────────┐
│ cat ┆ max_col ┆ other_col │
│ --- ┆ ---     ┆ ---       │
│ i64 ┆ i64     ┆ i64       │
╞═════╪═════════╪═══════════╡
│ 1   ┆ 36      ┆ 75        │
│ 2   ┆ 50      ┆ 150       │
│ 3   ┆ 60      ┆ 275       │
└─────┴─────────┴───────────┘

However, I'd prefer to simplify the above to only require passing in references to the max and category columns.

Am I missing something obvious here?

EDIT: Added example dataframe and output.


Solution

  • You can combine the max comparison with .is_first_distinct() and evaluate both per category with .over("cat"), so only the first row holding the maximum is kept.

    df.filter(
        pl.all_horizontal(
            pl.col("max_col") == pl.col("max_col").max(),
            pl.col("max_col").is_first_distinct()
        )
        .over("cat")
    )
    
    shape: (3, 3)
    ┌─────┬─────────┬───────────┐
    │ cat ┆ max_col ┆ other_col │
    │ --- ┆ ---     ┆ ---       │
    │ i64 ┆ i64     ┆ i64       │
    ╞═════╪═════════╪═══════════╡
    │ 1   ┆ 36      ┆ 75        │
    │ 2   ┆ 50      ┆ 150       │
    │ 3   ┆ 60      ┆ 275       │
    └─────┴─────────┴───────────┘