
Rust polars: unexpected behaviour of when().then().otherwise() in groupby-agg context


I have some complicated mapping logic that I want to execute within a groupby context. The code compiles and doesn't panic, but the results are incorrect. I know the logic itself is correct, so I wonder whether when-then-otherwise is supposed to be used within a groupby at all.

use polars::prelude::*;
use polars::df;

fn main() {
    let df = df! [
        "Region" => ["EU", "EU", "EU", "EU", "EU"],
        "MonthCCY" => ["APRUSD", "MAYUSD", "JUNEUR", "JULUSD", "APRUSD"],
        "values" => [1, 2, 3, 4, 5],
    ].unwrap();

    let df = df
        .lazy()
        .groupby_stable([col("MonthCCY")])
        .agg([month_weight().alias("Weight")]);

    println!("{:?}", df.collect().unwrap());
}

pub fn month_weight() -> Expr {
    when(col("Region").eq(lit("EU")))
    .then(
        // First, If MonthCCY is JUNEUR or AUGEUR - apply 0.05
        when(col("MonthCCY").map( |s|{
            Ok( s.utf8()?
            .contains("JUNEUR|AUGEUR")?
            .into_series() )
         }
            , GetOutput::from_type(DataType::Boolean)
        ))
        .then(lit::<f64>(0.05))
        .otherwise(
            // Second, If MonthCCY is JANEUR - apply 0.0225
            when(col("MonthCCY").map( |s|{
                Ok( s.utf8()?
                .contains("JANEUR")?
                .into_series() )
             }
                , GetOutput::from_type(DataType::Boolean)
            ))
            .then(lit::<f64>(0.0225))
            .otherwise(
                // Third, If MonthCCY starts with JUL or FEB (eg FEBUSD or FEBEUR)- apply 0.15
                when(col("MonthCCY").apply( |s|{
                    let x = s.utf8()?
                    .str_slice(0, Some(3))?;
                    let y = x.contains("JUL|FEB")?
                    .into_series();
                    Ok(y)
                 }
                    , GetOutput::from_type(DataType::Boolean)
                ))
                .then(lit::<f64>(0.15))
                //Finally, if none of the above matched, apply 0.2
                .otherwise(lit::<f64>(0.20))
            )
        )
    ).otherwise(lit::<f64>(0.0))
}

The result I am getting is:

┌──────────┬─────────────┐
│ MonthCCY ┆ Weight      │
│ ---      ┆ ---         │
│ str      ┆ list [f64]  │
╞══════════╪═════════════╡
│ APRUSD   ┆ [0.2, 0.15] │
├╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ MAYUSD   ┆ [0.2]       │
├╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ JUNEUR   ┆ [0.05]      │
├╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ JULUSD   ┆ [0.2]       │
└──────────┴─────────────┘

Clearly, I would expect JULUSD to be [0.15] and APRUSD to be [0.2, 0.2].

Is my expectation of how when().then().otherwise() works within groupby wrong?

I am on Windows 11, rustc 1.60.


Solution

  • Yep, you're doing the groupby and the mapping in the wrong order. month_weight() is not an aggregation expression but a simple mapping expression.

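    As written, the mapping is evaluated inside the aggregation, and the values collected for each group end up derived from the row order of the original frame rather than from the rows belonging to that group. That is why the weights land in the wrong groups (for example, JULUSD's 0.15 shows up in the APRUSD list).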

    You first want to create a Weight column with values given by the mapping you specify in month_weight(), and then you want to aggregate this column into a list for each group.

    So, what you want is the following:

    let df = df
        .lazy()
        .with_column(month_weight().alias("Weight")) // create new column first
        .groupby_stable([col("MonthCCY")]) // then group
        .agg([col("Weight").list()]); // then collect into a list per group
    
    println!("{:?}", df.collect().unwrap());
    

    Output:

    shape: (4, 2)
    ┌──────────┬────────────┐
    │ MonthCCY ┆ Weight     │
    │ ---      ┆ ---        │
    │ str      ┆ list [f64] │
    ╞══════════╪════════════╡
    │ APRUSD   ┆ [0.2, 0.2] │
    ├╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌┤
    │ MAYUSD   ┆ [0.2]      │
    ├╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌┤
    │ JUNEUR   ┆ [0.05]     │
    ├╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌┤
    │ JULUSD   ┆ [0.15]     │
    └──────────┴────────────┘
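To make the "map first, then group" order concrete without depending on a particular polars version, here is a plain-Rust sketch (std only) that applies the same rules row by row and then groups the results in first-appearance order, which is the ordering groupby_stable is assumed to preserve here. The helpers month_weight_row and group_stable are hypothetical, written only for this illustration; the regex alternations are replaced by two plain substring checks, and the str_slice(0, Some(3)) plus contains("JUL|FEB") step is replaced by starts_with, which is equivalent for these three-letter prefixes.

```rust
// Plain-Rust sketch (no polars) of "compute Weight per row, then group".
// `month_weight_row` and `group_stable` are hypothetical helpers that mirror
// the rules in month_weight() and the first-appearance order of groupby_stable.

/// Per-row version of the rules in month_weight().
fn month_weight_row(region: &str, month_ccy: &str) -> f64 {
    if region != "EU" {
        return 0.0;
    }
    if month_ccy.contains("JUNEUR") || month_ccy.contains("AUGEUR") {
        0.05
    } else if month_ccy.contains("JANEUR") {
        0.0225
    } else if month_ccy.starts_with("JUL") || month_ccy.starts_with("FEB") {
        0.15
    } else {
        0.20
    }
}

/// Collects values into one list per key, keeping keys in first-appearance order.
fn group_stable(keys: &[&str], values: &[f64]) -> Vec<(String, Vec<f64>)> {
    let mut groups: Vec<(String, Vec<f64>)> = Vec::new();
    for (k, v) in keys.iter().zip(values) {
        if let Some(i) = groups.iter().position(|(g, _)| g.as_str() == *k) {
            groups[i].1.push(*v);
        } else {
            groups.push((k.to_string(), vec![*v]));
        }
    }
    groups
}

fn main() {
    let region = ["EU", "EU", "EU", "EU", "EU"];
    let month_ccy = ["APRUSD", "MAYUSD", "JUNEUR", "JULUSD", "APRUSD"];

    // Step 1: the analogue of with_column(month_weight().alias("Weight")).
    let weight: Vec<f64> = region
        .iter()
        .zip(month_ccy.iter())
        .map(|(r, m)| month_weight_row(r, m))
        .collect();

    // Step 2: the analogue of groupby_stable([col("MonthCCY")]) + agg(list()).
    for (key, list) in group_stable(&month_ccy, &weight) {
        println!("{key}: {list:?}");
    }
}
```

Running it prints APRUSD: [0.2, 0.2], MAYUSD: [0.2], JUNEUR: [0.05] and JULUSD: [0.15], the same lists as the polars output above, because every weight is computed from its own row before any grouping happens.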
    

    Also, as an aside, when().then() can be chained indefinitely, so you don't need to nest them. Just as you can write a chained if ... else if ... else if ... else, you can write when(...).then(...).when(...).then(...) ... .otherwise(...), which is a lot simpler than nesting each additional condition inside an otherwise().
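The if/else-if analogy can be checked in plain Rust. The sketch below writes the MonthCCY rules from month_weight() twice: once nested, the way the question nests when(...).then(...).otherwise(when(...)...), and once as a flat else-if chain, the shape of when(...).then(...).when(...).then(...).otherwise(...). The function names weight_nested and weight_chained are hypothetical, and this only illustrates the control-flow equivalence, not the polars expression API itself.

```rust
// Hypothetical helpers: the same MonthCCY rules written nested and chained.
// Only the first matching branch is taken in both versions, so they agree.

/// Nested form, mirroring otherwise(when(...).then(...).otherwise(...)).
#[allow(clippy::collapsible_else_if)]
fn weight_nested(m: &str) -> f64 {
    if m.contains("JUNEUR") || m.contains("AUGEUR") {
        0.05
    } else {
        if m.contains("JANEUR") {
            0.0225
        } else {
            if m.starts_with("JUL") || m.starts_with("FEB") {
                0.15
            } else {
                0.20
            }
        }
    }
}

/// Flat chain, mirroring when(...).then(...).when(...).then(...).otherwise(...).
fn weight_chained(m: &str) -> f64 {
    if m.contains("JUNEUR") || m.contains("AUGEUR") {
        0.05
    } else if m.contains("JANEUR") {
        0.0225
    } else if m.starts_with("JUL") || m.starts_with("FEB") {
        0.15
    } else {
        0.20
    }
}

fn main() {
    for m in ["APRUSD", "MAYUSD", "JUNEUR", "JULUSD", "JANEUR", "FEBEUR", "AUGEUR"] {
        // Both forms must produce the same weight for every input.
        assert_eq!(weight_nested(m), weight_chained(m));
        println!("{m}: {}", weight_chained(m));
    }
}
```

The chained version reads top to bottom as a list of rules with a single fallback, which is the same readability gain the flat when/then chain gives over nesting.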