
How to get the index of the max value in each list of a list column in Python Polars?


I have a data frame:

import polars as pl

df = pl.DataFrame({'no_of_docs': [[9, 4, 2],
                                  [3, 9, 1, 10],
                                  [10, 3, 2, 1],
                                  [10, 30],
                                  [1, 2, 3, 6, 4, 5]]})

Here the column no_of_docs is of type list[int]. I would like to add a new column holding the index of the max value in each list.

Another case:

df = pl.DataFrame({'no_of_docs': [['9', '4', '2'],
                                  ['3', '9', '1', '10'],
                                  ['10', '3', '2', '1'],
                                  ['10', '30'],
                                  ['1', '2', '3', '6', '4', '5']]})

Here no_of_docs is of type list[str]. How can I convert it to int and get the index of the max value?

Expected output for the list[str] example:

# increase repr defaults
pl.Config(
    fmt_table_cell_list_len=10, 
    fmt_str_lengths=80
)

print(result)
shape: (5, 2)
┌────────────────────────────────┬─────┐
│ no_of_docs                     ┆ idx │
│ ---                            ┆ --- │
│ list[str]                      ┆ u32 │
╞════════════════════════════════╪═════╡
│ ["9", "4", "2"]                ┆ 0   │
│ ["3", "9", "1", "10"]          ┆ 3   │
│ ["10", "3", "2", "1"]          ┆ 0   │
│ ["10", "30"]                   ┆ 1   │
│ ["1", "2", "3", "6", "4", "5"] ┆ 3   │
└────────────────────────────────┴─────┘

Solution

  • You mostly answered the question yourself, but in case you still need the cast to List[i64], here is the solution:

    df.with_columns(
        pl.col("no_of_docs").cast(pl.List(pl.Int64)).list.arg_max().alias('idx')
    )
    
    shape: (5, 2)
    ┌────────────────────────────────┬─────┐
    │ no_of_docs                     ┆ idx │
    │ ---                            ┆ --- │
    │ list[str]                      ┆ u32 │
    ╞════════════════════════════════╪═════╡
    │ ["9", "4", "2"]                ┆ 0   │
    │ ["3", "9", "1", "10"]          ┆ 3   │
    │ ["10", "3", "2", "1"]          ┆ 0   │
    │ ["10", "30"]                   ┆ 1   │
    │ ["1", "2", "3", "6", "4", "5"] ┆ 3   │
    └────────────────────────────────┴─────┘
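To see why the cast matters for the list[str] case, here is a minimal pure-Python sketch of the per-row computation, using the rows from the question. The helper names `arg_max_as_int` and `arg_max_as_str` are hypothetical and not part of Polars. The sketch only illustrates the logic: strings compare lexicographically in Python, so "9" sorts above "10", and the index of the maximum can differ from the numeric one.

```python
# Pure-Python reference for the per-row logic; this is not the Polars implementation.
# Helper names are illustrative assumptions.

def arg_max_as_int(values: list[str]) -> int:
    """Index of the largest element after converting each string to int."""
    numbers = [int(v) for v in values]
    return numbers.index(max(numbers))  # first occurrence of the maximum


def arg_max_as_str(values: list[str]) -> int:
    """Index of the largest element under plain (lexicographic) string comparison."""
    return values.index(max(values))


rows = [
    ["9", "4", "2"],
    ["3", "9", "1", "10"],
    ["10", "3", "2", "1"],
    ["10", "30"],
    ["1", "2", "3", "6", "4", "5"],
]

idx_int = [arg_max_as_int(r) for r in rows]
idx_str = [arg_max_as_str(r) for r in rows]

print(idx_int)  # [0, 3, 0, 1, 3], matching the idx column above
print(idx_str)  # [0, 1, 1, 1, 3], differs on rows 2 and 3
```

The integer version reproduces the expected idx column, while the string version picks "9" over "10" and "3" over "10". That is why the values are cast to integers before taking the arg max.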