(Polars) How to get element from a column with list by index specified in another column


I have a dataframe with two columns: the first contains lists, and the second contains integer indexes.

How can I get the element from each list in the first column at the index given in the second column? Even better, how can I put that element in a third column?

Input example

import polars as pl

df = pl.DataFrame({
    "lst": [[1, 2, 3], [4, 5, 6]],
    "ind": [1, 2]
})
┌───────────┬─────┐
│ lst       ┆ ind │
│ ---       ┆ --- │
│ list[i64] ┆ i64 │
╞═══════════╪═════╡
│ [1, 2, 3] ┆ 1   │
│ [4, 5, 6] ┆ 2   │
└───────────┴─────┘

Expected output:

res = df.with_columns(pl.Series("list[ind]", [2, 6]))
┌───────────┬─────┬───────────┐
│ lst       ┆ ind ┆ list[ind] │
│ ---       ┆ --- ┆ ---       │
│ list[i64] ┆ i64 ┆ i64       │
╞═══════════╪═════╪═══════════╡
│ [1, 2, 3] ┆ 1   ┆ 2         │
│ [4, 5, 6] ┆ 2   ┆ 6         │
└───────────┴─────┴───────────┘

Thanks.


Solution

  • Update: This can now be done more easily with list.get(), which accepts an expression as the index:

    df.with_columns(pl.col("lst").list.get(pl.col("ind")).alias("list[ind]"))
    

    Original answer

    You can use with_row_index() to add a row index column for grouping, then explode() the list so that each list element is on its own row. Then call gather() with the "ind" column, using over() on the row index column so that the element is selected within each original row's subgroup.

    import polars as pl

    df = pl.DataFrame({"lst": [[1, 2, 3], [4, 5, 6]], "ind": [1, 2]})
    
    df = (
        df.with_row_index()
        .with_columns(
            pl.col("lst").explode().gather(pl.col("ind")).over(pl.col("index")).alias("list[ind]")
        )
        .drop("index")
    )
    
    shape: (2, 3)
    ┌───────────┬─────┬───────────┐
    │ lst       ┆ ind ┆ list[ind] │
    │ ---       ┆ --- ┆ ---       │
    │ list[i64] ┆ i64 ┆ i64       │
    ╞═══════════╪═════╪═══════════╡
    │ [1, 2, 3] ┆ 1   ┆ 2         │
    │ [4, 5, 6] ┆ 2   ┆ 6         │
    └───────────┴─────┴───────────┘