
Find value of column based on another column condition (max) in polars for many columns


If I have this dataframe:

import polars as pl
pl.DataFrame(dict(x=[0, 1, 2, 3], y=[5, 2, 3, 3], z=[4, 7, 8, 2]))
shape: (4, 3)
┌─────┬─────┬─────┐
│ x   ┆ y   ┆ z   │
│ --- ┆ --- ┆ --- │
│ i64 ┆ i64 ┆ i64 │
╞═════╪═════╪═════╡
│ 0   ┆ 5   ┆ 4   │
│ 1   ┆ 2   ┆ 7   │
│ 2   ┆ 3   ┆ 8   │
│ 3   ┆ 3   ┆ 2   │
└─────┴─────┴─────┘

and I want to find the value of x where y is at its maximum, then the value of x where z is at its maximum, and repeat this for hundreds more columns, so that I end up with something like:

shape: (2, 2)
┌────────┬─────────┐
│ column ┆ x_value │
│ ---    ┆ ---     │
│ str    ┆ i64     │
╞════════╪═════════╡
│ y      ┆ 0       │
│ z      ┆ 2       │
└────────┴─────────┘

or

shape: (1, 2)
┌─────┬─────┐
│ y   ┆ z   │
│ --- ┆ --- │
│ i64 ┆ i64 │
╞═════╪═════╡
│ 0   ┆ 2   │
└─────┴─────┘

What is the best polars way to do that?


Solution

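  • Update: Expr.top_k_by() has since been added to Polars. With k=1 it returns the x value from the row where the given column is largest, which produces the wide layout directly: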

    df.select(
       y = pl.col("x").top_k_by("y", k=1),
       z = pl.col("x").top_k_by("z", k=1)
    )
    
    shape: (1, 2)
    ┌─────┬─────┐
    │ y   ┆ z   │
    │ --- ┆ --- │
    │ i64 ┆ i64 │
    ╞═════╪═════╡
    │ 0   ┆ 2   │
    └─────┴─────┘
    

    Alternatively, you could reshape to long format with .unpivot():

    df.unpivot(index="x")
    
    shape: (8, 3)
    ┌─────┬──────────┬───────┐
    │ x   ┆ variable ┆ value │
    │ --- ┆ ---      ┆ ---   │
    │ i64 ┆ str      ┆ i64   │
    ╞═════╪══════════╪═══════╡
    │ 0   ┆ y        ┆ 5     │
    │ 1   ┆ y        ┆ 2     │
    │ 2   ┆ y        ┆ 3     │
    │ 3   ┆ y        ┆ 3     │
    │ 0   ┆ z        ┆ 4     │
    │ 1   ┆ z        ┆ 7     │
    │ 2   ┆ z        ┆ 8     │
    │ 3   ┆ z        ┆ 2     │
    └─────┴──────────┴───────┘
    

    and then use .filter() to keep the rows where value equals the maximum within each variable:

    (df.unpivot(index="x")
       .filter(pl.col.value == pl.col.value.max().over("variable"))
    )
    
    shape: (2, 3)
    ┌─────┬──────────┬───────┐
    │ x   ┆ variable ┆ value │
    │ --- ┆ ---      ┆ ---   │
    │ i64 ┆ str      ┆ i64   │
    ╞═════╪══════════╪═══════╡
    │ 0   ┆ y        ┆ 5     │
    │ 2   ┆ z        ┆ 8     │
    └─────┴──────────┴───────┘
    

    If a column's maximum occurs in more than one row, you will get one row per tie.

    If only a single row per column should be chosen, .arg_max() can be used instead:

    (df.unpivot(index="x")
       .group_by("variable")
       .agg(
          pl.all().get(pl.col.value.arg_max())
       )
    )