Search code examples
rust rust-polars

How can I apply a custom function and return named fields (Struct) in Rust Polars


Below is a simple example of a groupby-agg operation where I want to return an array/vector of min and max values of each group as a single column in the result.

// Groups `pydf` by `by_cols` and, for every group, packs the minimum and the
// maximum of column `colnm` into a two-element series (minimum first).
//
// Errors raised while collecting the lazy query are converted through
// `PyPolarsErr`. Both `min` and `max` are unwrapped, so a group for which
// either of them yields `None` makes the closure panic.
#[pyfunction]
fn test_fn(pydf: PyDataFrame, colnm: &str, by_cols: Vec<&str>) -> PyResult<PyDataFrame> {
    let frame: DataFrame = pydf.into();

    // Per-group reduction: one series holding `[min, max]`.
    let min_max = col(colnm).apply(
        |group| {
            let bounds: Vec<f64> = vec![group.min().unwrap(), group.max().unwrap()];
            Ok(Some(Series::new("s", bounds)))
        },
        GetOutput::default(),
    );

    let aggregated = frame
        .lazy()
        .groupby(by_cols)
        .agg([min_max])
        .collect()
        .map_err(PyPolarsErr::from)?;
    Ok(PyDataFrame(aggregated))
}

// Module initializer: registers `test_fn` on the Python module `test_module`.
#[pymodule]
fn test_module(_py: Python, m: &PyModule) -> PyResult<()> {
    let exported = wrap_pyfunction!(test_fn, m)?;
    m.add_function(exported)
}

As you can see from the following code section, column a in the resulting dataframe contains a list of two elements (min and max).

import polars as pl
import test_module

# Five float values in "a" and two integer key columns, "g1" and "g2".
df = pl.DataFrame(
    {
        "a": [1.0, 2.0, 3.0, 4.0, 5.0],
        "g1": [1, 1, 2, 2, 2],
        "g2": [1, 1, 1, 2, 2],
    }
)

>>> test_module.test_fn(df, "a", ["g1", "g2"])
shape: (3, 3)
┌─────┬─────┬────────────┐
│ g1  ┆ g2  ┆ a          │
│ --- ┆ --- ┆ ---        │
│ i64 ┆ i64 ┆ list[f64]  │
╞═════╪═════╪════════════╡
│ 1   ┆ 1   ┆ [1.0, 2.0] │
│ 2   ┆ 2   ┆ [4.0, 5.0] │
│ 2   ┆ 1   ┆ [3.0, 3.0] │
└─────┴─────┴────────────┘

Now, I am curious how I can modify my test_fn above so that it returns a struct/dict/hashmap instead of a vector, with the benefit of having named fields in the result.

More specifically, what I want is:

>>> test_module.test_fn(df, "a", ["g1", "g2"])
shape: (3, 3)
┌─────┬─────┬───────────┐
│ g1  ┆ g2  ┆ a         │
│ --- ┆ --- ┆ ---       │
│ i64 ┆ i64 ┆ struct[2] │
╞═════╪═════╪═══════════╡
│ 1   ┆ 1   ┆ {1.0,2.0} │
│ 2   ┆ 2   ┆ {4.0,5.0} │
│ 2   ┆ 1   ┆ {3.0,3.0} │
└─────┴─────┴───────────┘

Or

>>> test_module.test_fn(df, "a", ["g1", "g2"])
shape: (3, 4)
┌─────┬─────┬───────┬───────┐
│ g1  ┆ g2  ┆ a_min ┆ a_max │
│ --- ┆ --- ┆ ---   ┆ ---   │
│ i64 ┆ i64 ┆ f64   ┆ f64   │
╞═════╪═════╪═══════╪═══════╡
│ 2   ┆ 1   ┆ 3.0   ┆ 3.0   │
│ 2   ┆ 2   ┆ 4.0   ┆ 5.0   │
│ 1   ┆ 1   ┆ 1.0   ┆ 2.0   │
└─────┴─────┴───────┴───────┘

Solution

  • The closure passed to .apply returns a two-element series, which is why the aggregated column comes out as a list. To get named fields instead, the aggregation has to produce a struct column whose fields carry the names we want. Here's how you can return a struct:

    #[derive(Debug, Clone)]
    struct AggResult {
        a_min: f64,
        a_max: f64,
    }
    
    #[pyfunction]
    fn test_fn(pydf: PyDataFrame, colnm: &str, by_cols: Vec<&str>) -> PyResult<PyDataFrame> {
        let df: DataFrame = pydf.into();
        let res = df
            .lazy()
            .groupby(by_cols)
            .agg([
                col(colnm)
                    .apply(
                        |s| {
                            let agg_result = AggResult {
                                a_min: s.min().unwrap(),
                                a_max: s.max().unwrap(),
                            };
                            Ok(Some(Series::new(
                                "agg_result",
                                vec![agg_result],
                            )))
                        },
                        GetOutput::default(),
                    )
            ])
            .collect()
            .map_err(PyPolarsErr::from)?;
        Ok(PyDataFrame(res))
    }
    
    #[pymodule]
    fn test_module(_py: Python, m: &PyModule) -> PyResult<()> {
        m.add_class::<AggResult>()?;
        m.add_function(wrap_pyfunction!(test_fn, m)?)?;
        Ok(())
    }
    

    With this change, the call below should give the first result you wanted (a single struct column):

    import polars as pl
    import test_module
    
    # Five float values in "a" and two integer key columns, "g1" and "g2".
    df = pl.DataFrame(
        {
            "a": [1.0, 2.0, 3.0, 4.0, 5.0],
            "g1": [1, 1, 2, 2, 2],
            "g2": [1, 1, 1, 2, 2],
        }
    )
    
    >>> test_module.test_fn(df, "a", ["g1", "g2"])
    shape: (3, 3)
    ┌─────┬─────┬────────────┐
    │ g1  ┆ g2  ┆ agg_result │
    │ --- ┆ --- ┆ ---        │
    │ i64 ┆ i64 ┆ struct[2]  │
    ╞═════╪═════╪════════════╡
    │ 1   ┆ 1   ┆ {1.0,2.0}  │
    │ 2   ┆ 1   ┆ {3.0,3.0}  │
    │ 2   ┆ 2   ┆ {4.0,5.0}  │
    └─────┴─────┴────────────┘