Search code examples
pythonpython-polars

Python-Polars: Cross field calculation of struct columns


I am trying to build a function that takes a list of struct columns, extracts two fields from each, and performs a cross-field combination of all the values of those fields, all within the same context. For example, for:

# Every entry of a struct column is a dict with the fields
# value_decimal, value_decimal_lb and value_decimal_ub.
value_1_rows = [
    {'value_decimal': 22.5, 'value_decimal_lb': 22.0, 'value_decimal_ub': 23.0},
    {'value_decimal': 23.0, 'value_decimal_lb': 22.5, 'value_decimal_ub': 23.5},
]
value_2_rows = [
    {'value_decimal': 5.0, 'value_decimal_lb': 5.0, 'value_decimal_ub': 5.0},
    {'value_decimal': 5.0, 'value_decimal_lb': 6.0, 'value_decimal_ub': 5.0},
]

# One row per case; the dict lists become struct columns.
df = pl.DataFrame(
    {
        'cases': ['case_1', 'case_2'],
        'value_1': value_1_rows,
        'value_2': value_2_rows,
    }
)

print(df)

shape: (2, 3)
┌────────┬──────────────────┬───────────────┐
│ cases  ┆ value_1          ┆ value_2       │
│ ---    ┆ ---              ┆ ---           │
│ str    ┆ struct[3]        ┆ struct[3]     │
╞════════╪══════════════════╪═══════════════╡
│ case_1 ┆ {22.5,22.0,23.0} ┆ {5.0,5.0,5.0} │
│ case_2 ┆ {23.0,22.5,23.5} ┆ {5.0,6.0,5.0} │
└────────┴──────────────────┴───────────────┘

I would like to get:

shape: (2, 4)
┌────────┬──────────────────┬───────────────┬─────────────────────────────────┐
│ cases  ┆ value_1          ┆ value_2       ┆ result                          │
│ ---    ┆ ---              ┆ ---           ┆ ---                             │
│ str    ┆ struct[3]        ┆ struct[3]     ┆ list[list[f64]]                 │
╞════════╪══════════════════╪═══════════════╪═════════════════════════════════╡
│ case_1 ┆ {22.5,22.0,23.0} ┆ {5.0,5.0,5.0} ┆ [[22.0, 5.0], [22.0, 5.0], … [… │
│ case_2 ┆ {23.0,22.5,23.5} ┆ {5.0,6.0,5.0} ┆ [[22.5, 5.0], [22.5, 6.0], … [… │
└────────┴──────────────────┴───────────────┴─────────────────────────────────┘

This is something I can achieve by nesting map_elements():

from itertools import product

# Names of the struct columns whose bounds are combined.
struct_columns = ['value_1', 'value_2']

# Build the intermediate field names from the plain column names. Formatting a
# pl.Expr into an f-string would embed the expression repr (e.g.
# 'col("value_1")_list') instead of the column name.
all_combinations = pl.struct(
    [
        # Extract the [lb, ub] list for each struct column
        pl.col(name).map_elements(
            lambda x: [x['value_decimal_lb'], x['value_decimal_ub']],
            return_dtype=pl.List(pl.Float64),
        ).alias(f'{name}_list')
        for name in struct_columns
    ]
).map_elements(
    # Compute all cross-column combinations of the extracted bounds
    lambda x: [
        list(combo)
        for combo in product(*(x[f'{name}_list'] for name in struct_columns))
    ],
    # map_elements cannot reliably infer nested output types, so state it.
    return_dtype=pl.List(pl.List(pl.Float64)),
).alias('result')  # name the new column as in the desired output

print(df.with_columns(all_combinations))

However, this solution is sub-optimal, and I need to know in advance the return_dtype of the column, which may differ depending on the input columns.

Would you have any suggestion on how to implement this solution without using map_elements?


Solution

  • (
        # Step 1: for every row, gather the lb/ub fields of each struct column
        # into a two-element list column (result1 from value_1, result2 from value_2).
        df.with_columns(
            result1 = pl.concat_list(pl.col.value_1.struct.field("value_decimal_lb", "value_decimal_ub")),
            result2 = pl.concat_list(pl.col.value_2.struct.field("value_decimal_lb", "value_decimal_ub"))
        )
        # Step 2: exploding the two list columns one after the other produces
        # every (result1, result2) pairing for each original row.
        .explode("result1")
        .explode("result2")
        # Step 3: merge each pairing into a single two-element list.
        .with_columns(result = pl.concat_list("result1","result2"))
        # Step 4: collapse back to one row per case, collecting the pairs into a
        # list and keeping the first (repeated) copy of the original structs.
        # NOTE(review): group_by is called without maintain_order, so the output
        # row order may differ from the input; this also assumes "cases" is
        # unique per input row — confirm.
        .group_by("cases")
        .agg("result", pl.col("value_1","value_2").first())
    )
    
    shape: (2, 4)
    ┌────────┬──────────────────┬───────────────┬─────────────────────────────────┐
    │ cases  ┆ value_1          ┆ value_2       ┆ result                          │
    │ ---    ┆ ---              ┆ ---           ┆ ---                             │
    │ str    ┆ struct[3]        ┆ struct[3]     ┆ list[list[f64]]                 │
    ╞════════╪══════════════════╪═══════════════╪═════════════════════════════════╡
    │ case_2 ┆ {23.0,22.5,23.5} ┆ {5.0,6.0,5.0} ┆ [[22.5, 6.0], [22.5, 5.0], … [… │
    │ case_1 ┆ {22.5,22.0,23.0} ┆ {5.0,5.0,5.0} ┆ [[22.0, 5.0], [22.0, 5.0], … [… │
    └────────┴──────────────────┴───────────────┴─────────────────────────────────┘
    

    Or something like this:

    # Struct fields taken from each column; every (value_1 field, value_2 field)
    # pairing becomes one struct in the resulting list.
    bound_fields = ["value_decimal_lb", "value_decimal_ub"]

    # The value_1 field varies in the outer loop, the value_2 field in the inner one.
    paired_structs = [
        pl.struct(
            f1=pl.col("value_1").struct.field(first_field),
            f2=pl.col("value_2").struct.field(second_field),
        )
        for first_field in bound_fields
        for second_field in bound_fields
    ]

    df.with_columns(result=pl.concat_list(paired_structs))
    
    shape: (2, 4)
    ┌────────┬──────────────────┬───────────────┬─────────────────────────────────┐
    │ cases  ┆ value_1          ┆ value_2       ┆ result                          │
    │ ---    ┆ ---              ┆ ---           ┆ ---                             │
    │ str    ┆ struct[3]        ┆ struct[3]     ┆ list[struct[2]]                 │
    ╞════════╪══════════════════╪═══════════════╪═════════════════════════════════╡
    │ case_1 ┆ {22.5,22.0,23.0} ┆ {5.0,5.0,5.0} ┆ [{22.0,5.0}, {22.0,5.0}, … {23… │
    │ case_2 ┆ {23.0,22.5,23.5} ┆ {5.0,6.0,5.0} ┆ [{22.5,6.0}, {22.5,5.0}, … {23… │
    └────────┴──────────────────┴───────────────┴─────────────────────────────────┘
    
    
    Or even:
    
    # Struct fields whose values are combined across the two columns.
    fields = ["value_decimal_lb", "value_decimal_ub"]
    
    # Stack one frame per field, so every case gets one row per selected field
    # value (v1 taken from value_1, v2 taken from value_2).
    df1 = pl.concat(df.select("cases", v1 = pl.col.value_1.struct.field(x)) for x in fields)
    df2 = pl.concat(df.select("cases", v2 = pl.col.value_2.struct.field(x)) for x in fields)
    
    
    df_result = (
        # Joining on "cases" pairs every v1 row with every v2 row of the same case.
        df1.join(df2, on="cases")
        # Turn each (v1, v2) pair into a two-element list.
        .select("cases", result = pl.concat_list("v1","v2"))
        # Collect the pairs of each case into one list; maintain_order=True keeps
        # the groups in the order in which the cases first appear.
        .group_by("cases", maintain_order=True)
        .agg("result")
    )
    
    # Attach the result column to the original frame.
    # NOTE(review): how="align" presumably lines the frames up on their shared
    # column ("cases") — confirm against the Polars concat documentation.
    pl.concat([df, df_result], how="align")
    
    shape: (2, 4)
    ┌────────┬──────────────────┬───────────────┬─────────────────────────────────┐
    │ cases  ┆ value_1          ┆ value_2       ┆ result                          │
    │ ---    ┆ ---              ┆ ---           ┆ ---                             │
    │ str    ┆ struct[3]        ┆ struct[3]     ┆ list[list[f64]]                 │
    ╞════════╪══════════════════╪═══════════════╪═════════════════════════════════╡
    │ case_1 ┆ {22.5,22.0,23.0} ┆ {5.0,5.0,5.0} ┆ [[22.0, 5.0], [23.0, 5.0], … [… │
    │ case_2 ┆ {23.0,22.5,23.5} ┆ {5.0,6.0,5.0} ┆ [[22.5, 6.0], [23.5, 6.0], … [… │
    └────────┴──────────────────┴───────────────┴─────────────────────────────────┘