Polars keep the biggest value using 2 categories


I have a Polars DataFrame that contains IDs, actions, and values:

Example Dataframe:

import polars as pl

data = {
    "ID": [1, 1, 2, 2, 3, 3],
    "Action": ["A", "A", "B", "B", "A", "A"],
    "Where": ["Office", "Home", "Home", "Office", "Home", "Home"],
    "Value": [1, 2, 3, 4, 5, 6],
}

df = pl.DataFrame(data)

For each ID and Action, I want to select the row with the largest Value, so I know where that person prefers to perform the action.

I'm taking the following approach:

(
    df
    .select(
        pl.col("ID"),
        pl.col("Action"),
        pl.col("Where"),
        TOP = pl.col("Value").max().over(["ID", "Action"]))
)

After that, I sorted the rows and kept only the first row per ID and Action to retain the desired info; however, the output is incorrect:

(
    df
    .select(
        pl.col("ID"),
        pl.col("Action"),
        pl.col("Where"),
        TOP = pl.col("Value").max().over(["ID", "Action"]))
    .sort(
        pl.col("*"), descending =True
    )
    .unique(
        subset = ["ID", "Action"],
        maintain_order = True,
        keep = "first"
    )
)

Current Output :

shape: (3, 4)
┌─────┬────────┬────────┬─────┐
│ ID  ┆ Action ┆ Where  ┆ TOP │
│ --- ┆ ---    ┆ ---    ┆ --- │
│ i64 ┆ str    ┆ str    ┆ i64 │
╞═════╪════════╪════════╪═════╡
│ 3   ┆ A      ┆ Home   ┆ 6   │
│ 2   ┆ B      ┆ Office ┆ 4   │
│ 1   ┆ A      ┆ Office ┆ 2   │
└─────┴────────┴────────┴─────┘

Expected Output:

shape: (3, 4)
┌─────┬────────┬────────┬─────┐
│ ID  ┆ Action ┆ Where  ┆ TOP │
│ --- ┆ ---    ┆ ---    ┆ --- │
│ i64 ┆ str    ┆ str    ┆ i64 │
╞═════╪════════╪════════╪═════╡
│ 3   ┆ A      ┆ Home   ┆ 6   │
│ 2   ┆ B      ┆ Office ┆ 4   │
│ 1   ┆ A      ┆ Home   ┆ 2   │
└─────┴────────┴────────┴─────┘

Also, I don't think this approach is optimal.


Solution

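  • The original attempt fails because max().over() gives every row in a group the same TOP, so the descending sort on all columns breaks ties by Where ("Office" sorts before "Home") and unique keeps the wrong row.

  • The over and unique steps can be combined into a single group_by:

    • .arg_max() gives the index of the maximum Value within each group
    • .get() extracts the value at that index from each remaining column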
    (df.group_by("ID", "Action")
       .agg(
          pl.all().get(pl.col("Value").arg_max())
       )
    )
    
    shape: (3, 4)
    ┌─────┬────────┬────────┬───────┐
    │ ID  ┆ Action ┆ Where  ┆ Value │
    │ --- ┆ ---    ┆ ---    ┆ ---   │
    │ i64 ┆ str    ┆ str    ┆ i64   │
    ╞═════╪════════╪════════╪═══════╡
    │ 1   ┆ A      ┆ Home   ┆ 2     │
    │ 2   ┆ B      ┆ Office ┆ 4     │
    │ 3   ┆ A      ┆ Home   ┆ 6     │
    └─────┴────────┴────────┴───────┘