
How to filter polars dataframe by first maximum value while using over?


I am trying to filter a dataframe to find the first occurrence of a maximum value over a category column. In my data there is no guarantee that the maximum is unique within a category: several rows can share it, but I only need the first occurrence.

Yet I can't seem to find a way to limit the max part of the filter to a single row. Currently I add a further filter on another column, generally a time-based one, and take its minimum value.

import polars as pl

df = pl.DataFrame(
    {
        "cat": [1, 1, 1, 2, 2, 2, 2, 3, 3, 3],
        "max_col": [12, 24, 36, 15, 50, 50, 45, 20, 40, 60],
        "other_col": [25, 50, 75, 125, 150, 175, 200, 225, 250, 275],
    }
)

df = df.filter(pl.col("max_col") == pl.col("max_col").max().over("cat")).filter(
    pl.col("other_col") == pl.col("other_col").min().over("cat")
)

shape: (3, 3)
┌─────┬─────────┬───────────┐
│ cat ┆ max_col ┆ other_col │
│ --- ┆ ---     ┆ ---       │
│ i64 ┆ i64     ┆ i64       │
╞═════╪═════════╪═══════════╡
│ 1   ┆ 36      ┆ 75        │
│ 2   ┆ 50      ┆ 150       │
│ 3   ┆ 60      ┆ 275       │
└─────┴─────────┴───────────┘

However, I'd prefer to simplify the above to only require passing in references to the max and category columns.

Am I missing something obvious here?

EDIT: Added example dataframe and output.


Solution

  • It sounds like you're asking for .arg_max(), which gives the position of the maximum within each group, so the value of other_col is actually irrelevant.

    df.filter(
       (pl.int_range(pl.len()) == pl.col.max_col.arg_max()).over("cat")
    )
    
    shape: (3, 3)
    ┌─────┬─────────┬───────────┐
    │ cat ┆ max_col ┆ other_col │
    │ --- ┆ ---     ┆ ---       │
    │ i64 ┆ i64     ┆ i64       │
    ╞═════╪═════════╪═══════════╡
    │ 1   ┆ 36      ┆ 75        │
    │ 2   ┆ 50      ┆ 150       │
    │ 3   ┆ 60      ┆ 275       │
    └─────┴─────────┴───────────┘
    

    Explanation

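    pl.int_range(pl.len()) gives the row number within each group, and you keep the row whose number matches .arg_max(). For cat 2, where 50 appears twice, .arg_max() points at the first of the two rows: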

    df.with_columns(
       row_number = pl.int_range(pl.len()).over("cat"),
       row_i_want = pl.col.max_col.arg_max().over("cat")
    )
    
    shape: (10, 5)
    ┌─────┬─────────┬───────────┬────────────┬────────────┐
    │ cat ┆ max_col ┆ other_col ┆ row_number ┆ row_i_want │
    │ --- ┆ ---     ┆ ---       ┆ ---        ┆ ---        │
    │ i64 ┆ i64     ┆ i64       ┆ i64        ┆ u32        │
    ╞═════╪═════════╪═══════════╪════════════╪════════════╡
    │ 1   ┆ 12      ┆ 25        ┆ 0          ┆ 2          │
    │ 1   ┆ 24      ┆ 50        ┆ 1          ┆ 2          │
    │ 1   ┆ 36      ┆ 75        ┆ 2          ┆ 2          │ # KEEP
    │ 2   ┆ 15      ┆ 125       ┆ 0          ┆ 1          │
    │ 2   ┆ 50      ┆ 150       ┆ 1          ┆ 1          │ # KEEP
    │ 2   ┆ 50      ┆ 175       ┆ 2          ┆ 1          │
    │ 2   ┆ 45      ┆ 200       ┆ 3          ┆ 1          │
    │ 3   ┆ 20      ┆ 225       ┆ 0          ┆ 2          │
    │ 3   ┆ 40      ┆ 250       ┆ 1          ┆ 2          │
    │ 3   ┆ 60      ┆ 275       ┆ 2          ┆ 2          │ # KEEP
    └─────┴─────────┴───────────┴────────────┴────────────┘