
How to extract a value by index from a nested list column?


Given an aggregated DataFrame and an index DataFrame, how can I extract the element at each row's index from the list[...] columns?

df:

┌──────┬────────────┬───────────┐
│ read ┆ region     ┆ cov       │
│ ---  ┆ ---        ┆ ---       │
│ str  ┆ list[str]  ┆ list[i32] │
╞══════╪════════════╪═══════════╡
│ a    ┆ ["x", "y"] ┆ [25, 10]  │
│ b    ┆ ["x", "y"] ┆ [15, 30]  │
└──────┴────────────┴───────────┘

df_idx:

┌──────┬─────────┐
│ read ┆ cov_idx │
│ ---  ┆ ---     │
│ str  ┆ u32     │
╞══════╪═════════╡
│ a    ┆ 0       │
│ b    ┆ 1       │
└──────┴─────────┘

The following code generates example data. In the real use case, cov is produced by a more complicated function, and df_idx is derived from it with arg_max.

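use polars::prelude::*;
use polars::df;

fn main() -> PolarsResult<()> {
    // Long-format input: one row per (read, region) pair.
    let df0 = df![
        "read" => ["a", "a", "b", "b"],
        "region" => ["x", "y", "x", "y"],
        "cov" => [25, 10, 15, 30]
    ]?;
    // Collapse each read into one row of lists, keeping first-appearance order.
    let df = df0.lazy()
        .groupby_stable([col("read")])
        .agg([col("*")])
        .collect()?;
    // Index of the element to keep for each read (an arg_max in the real code).
    // Note: this literal builds an i32 column, whereas arg_max yields u32.
    let df_idx = df![
        "read" => ["a", "b"],
        "cov_idx" => [0, 1]
    ]?;
    println!("{df}\n{df_idx}");
    Ok(())
}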
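To make the shape of the intermediate data concrete, here is a plain-Rust sketch (standard library only, not polars) of what the stable group-by and the arg_max step compute: rows are grouped by read in order of first appearance, and cov_idx is the position of the largest cov in each group. The names group_stable, arg_max and Group are hypothetical helpers for illustration, and breaking ties toward the first maximum is an assumption of the sketch, not a statement about polars.

```rust
use std::collections::HashMap;

/// One aggregated group: the per-read lists (hypothetical type for this sketch).
#[derive(Debug, PartialEq)]
struct Group {
    read: String,
    region: Vec<String>,
    cov: Vec<i32>,
}

/// Group rows by `read`, keeping groups in order of first appearance.
fn group_stable(read: &[&str], region: &[&str], cov: &[i32]) -> Vec<Group> {
    // read -> position of its group in `groups`
    let mut pos: HashMap<&str, usize> = HashMap::new();
    let mut groups: Vec<Group> = Vec::new();
    for ((r, g), c) in read.iter().zip(region).zip(cov) {
        // Create the group the first time this read is seen.
        let i = *pos.entry(*r).or_insert_with(|| {
            groups.push(Group { read: r.to_string(), region: Vec::new(), cov: Vec::new() });
            groups.len() - 1
        });
        groups[i].region.push(g.to_string());
        groups[i].cov.push(*c);
    }
    groups
}

/// Position of the largest value; the first maximum wins on ties (sketch assumption).
fn arg_max(values: &[i32]) -> Option<u32> {
    let mut best: Option<(usize, i32)> = None;
    for (i, &v) in values.iter().enumerate() {
        if best.map_or(true, |(_, b)| v > b) {
            best = Some((i, v));
        }
    }
    best.map(|(i, _)| i as u32)
}

fn main() {
    let groups = group_stable(&["a", "a", "b", "b"], &["x", "y", "x", "y"], &[25, 10, 15, 30]);
    for g in &groups {
        println!("{} {:?} {:?} cov_idx={:?}", g.read, g.region, g.cov, arg_max(&g.cov));
    }
}
```

For the example input this yields cov_idx 0 for read a and 1 for read b, the same values as df_idx above.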

The expected result is:

result_df:

┌──────┬────────┬─────┐
│ read ┆ region ┆ cov │
│ ---  ┆ ---    ┆ --- │
│ str  ┆ str    ┆ i32 │
╞══════╪════════╪═════╡
│ a    ┆ "x"    ┆ 25  │
│ b    ┆ "y"    ┆ 30  │
└──────┴────────┴─────┘

Solution

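  • You can use the Lazy API's expr.list().get(idx) to fetch the element at position idx from each list. Because idx can itself be an expression, first left-join df_idx onto the aggregated frame and then pass col("cov_idx"), so each row uses its own index.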

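    use polars::prelude::*;
    use polars::df;

    fn main() -> PolarsResult<()> {
        let df0 = df![
            "read" => ["a", "a", "b", "b"],
            "region" => ["x", "y", "x", "y"],
            "cov" => [25, 10, 15, 30]
        ]?;
        let df_idx = df![
            "read" => ["a", "b"],
            "cov_idx" => [0, 1]
        ]?
        .lazy();
    
        let df = df0
            .lazy()
            // One row per read, with region and cov collected into lists.
            .groupby_stable([col("read")])
            .agg([col("*")])
            // Attach each read's index as an ordinary column.
            .left_join(df_idx, col("read"), col("read"))
            // Replace each list with its element at that row's cov_idx.
            .with_columns(["region", "cov"].map(|c| col(c).list().get(col("cov_idx"))));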
    
        println!("{:?}", df.collect()?);
    
        Ok(())
    }
    

    Result:

    shape: (2, 4)
    ┌──────┬────────┬─────┬─────────┐
    │ read ┆ region ┆ cov ┆ cov_idx │
    │ ---  ┆ ---    ┆ --- ┆ ---     │
    │ str  ┆ str    ┆ i32 ┆ i32     │
    ╞══════╪════════╪═════╪═════════╡
    │ a    ┆ x      ┆ 25  ┆ 0       │
    │ b    ┆ y      ┆ 30  ┆ 1       │
    └──────┴────────┴─────┴─────────┘
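The join-then-get idea can also be sketched without polars. This standard-library-only version (select_by_index and Row are hypothetical names) looks up each read's index, as the left join does, and then takes that position from every list; a missing read or an out-of-range index yields None here, which the sketch uses as a stand-in for a null. It illustrates the logic only and says nothing about how polars implements list().get.

```rust
use std::collections::HashMap;

/// One aggregated row: a read plus its list columns (hypothetical type for this sketch).
struct Row<'a> {
    read: &'a str,
    region: Vec<&'a str>,
    cov: Vec<i32>,
}

/// Look up each read's index (like the left join on "read"), then take the element
/// at that index from every list column. None stands in for a null result.
fn select_by_index<'a>(
    rows: &[Row<'a>],
    idx: &[(&str, u32)],
) -> Vec<(&'a str, Option<&'a str>, Option<i32>)> {
    // read -> index, built once from the index table.
    let lookup: HashMap<&str, usize> = idx.iter().map(|&(r, i)| (r, i as usize)).collect();
    rows.iter()
        .map(|row| {
            // No matching read in the index table => None, like a null after a left join.
            let i = lookup.get(row.read).copied();
            // Vec::get returns None for an out-of-range position instead of panicking.
            let region = i.and_then(|i| row.region.get(i).copied());
            let cov = i.and_then(|i| row.cov.get(i).copied());
            (row.read, region, cov)
        })
        .collect()
}

fn main() {
    let rows = vec![
        Row { read: "a", region: vec!["x", "y"], cov: vec![25, 10] },
        Row { read: "b", region: vec!["x", "y"], cov: vec![15, 30] },
    ];
    let idx = [("a", 0u32), ("b", 1u32)];
    for (read, region, cov) in select_by_index(&rows, &idx) {
        println!("{read} {region:?} {cov:?}");
    }
}
```

Running it prints a Some("x") Some(25) and b Some("y") Some(30), matching the a/x/25 and b/y/30 rows of the polars result above.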