I am trying to build a function that takes a list of struct columns, extracts two fields (the lower and upper bound) from each, and performs a cross-field combination (Cartesian product) of all the values of those fields — all within a single expression context. For example, for:
# Example input: one row per case, each with two struct columns holding a
# point value plus its lower (`_lb`) and upper (`_ub`) bound.
df = pl.DataFrame(
    {
        "cases": ["case_1", "case_2"],
        "value_1": [
            {"value_decimal": 22.5, "value_decimal_lb": 22.0, "value_decimal_ub": 23.0},
            {"value_decimal": 23.0, "value_decimal_lb": 22.5, "value_decimal_ub": 23.5},
        ],
        "value_2": [
            {"value_decimal": 5.0, "value_decimal_lb": 5.0, "value_decimal_ub": 5.0},
            {"value_decimal": 5.0, "value_decimal_lb": 6.0, "value_decimal_ub": 5.0},
        ],
    }
)
print(df)
shape: (2, 3)
┌────────┬──────────────────┬───────────────┐
│ cases ┆ value_1 ┆ value_2 │
│ --- ┆ --- ┆ --- │
│ str ┆ struct[3] ┆ struct[3] │
╞════════╪══════════════════╪═══════════════╡
│ case_1 ┆ {22.5,22.0,23.0} ┆ {5.0,5.0,5.0} │
│ case_2 ┆ {23.0,22.5,23.5} ┆ {5.0,6.0,5.0} │
└────────┴──────────────────┴───────────────┘
I would like to get:
shape: (2, 4)
┌────────┬──────────────────┬───────────────┬─────────────────────────────────┐
│ cases ┆ value_1 ┆ value_2 ┆ result │
│ --- ┆ --- ┆ --- ┆ --- │
│ str ┆ struct[3] ┆ struct[3] ┆ list[list[f64]] │
╞════════╪══════════════════╪═══════════════╪═════════════════════════════════╡
│ case_1 ┆ {22.5,22.0,23.0} ┆ {5.0,5.0,5.0} ┆ [[22.0, 5.0], [22.0, 5.0], … [… │
│ case_2 ┆ {23.0,22.5,23.5} ┆ {5.0,6.0,5.0} ┆ [[22.5, 5.0], [22.5, 6.0], … [… │
└────────┴──────────────────┴───────────────┴─────────────────────────────────┘
I can achieve this by nesting `map_elements()` calls:
from itertools import product

# Struct columns to combine, and the two struct fields taken from each.
value_cols = ["value_1", "value_2"]
bounds = ["value_decimal_lb", "value_decimal_ub"]

all_combinations = (
    pl.struct(
        [
            # Extract the [lb, ub] list for each column; alias by the plain
            # column name so the lookup below does not depend on Expr.__str__.
            pl.col(name)
            .map_elements(
                lambda x: [x[b] for b in bounds],
                return_dtype=pl.List(pl.Float64),
            )
            .alias(name)
            for name in value_cols
        ]
    )
    .map_elements(
        # Cartesian product of the per-column [lb, ub] lists, one inner list per combination.
        lambda x: [list(combo) for combo in product(*(x[name] for name in value_cols))],
        # Declared explicitly instead of letting polars infer it from the first row.
        return_dtype=pl.List(pl.List(pl.Float64)),
    )
    .alias("result")
)
print(df.with_columns(all_combinations))
However, this solution is sub-optimal, and I need to know the `return_dtype` of the column in advance, which may differ depending on the input columns.
Do you have any suggestions on how to implement this without using `map_elements`?
# Build every (value_1 bound, value_2 bound) pair by exploding the [lb, ub]
# list of each column in turn, then collect the pairs back per input row.
(
    # Group on the row position rather than `cases`, so rows that share a
    # case label are not merged together.
    df.with_row_index("_row")
    .with_columns(
        result1=pl.concat_list(pl.col.value_1.struct.field("value_decimal_lb", "value_decimal_ub")),
        result2=pl.concat_list(pl.col.value_2.struct.field("value_decimal_lb", "value_decimal_ub")),
    )
    .explode("result1")
    .explode("result2")
    .with_columns(result=pl.concat_list("result1", "result2"))
    # maintain_order=True keeps the output rows in the same order as `df`;
    # without it the group order is arbitrary.
    .group_by("_row", maintain_order=True)
    .agg(pl.col("cases", "value_1", "value_2").first(), "result")
    .drop("_row")
)
shape: (2, 4)
┌────────┬──────────────────┬───────────────┬─────────────────────────────────┐
│ cases ┆ value_1 ┆ value_2 ┆ result │
│ --- ┆ --- ┆ --- ┆ --- │
│ str ┆ struct[3] ┆ struct[3] ┆ list[list[f64]] │
╞════════╪══════════════════╪═══════════════╪═════════════════════════════════╡
│ case_2 ┆ {23.0,22.5,23.5} ┆ {5.0,6.0,5.0} ┆ [[22.5, 6.0], [22.5, 5.0], … [… │
│ case_1 ┆ {22.5,22.0,23.0} ┆ {5.0,5.0,5.0} ┆ [[22.0, 5.0], [22.0, 5.0], … [… │
└────────┴──────────────────┴───────────────┴─────────────────────────────────┘
Or something like this:
bound_fields = ["value_decimal_lb", "value_decimal_ub"]

# One {f1, f2} struct per (value_1 bound, value_2 bound) pair; the value_2
# bound varies fastest. concat_list gathers them into a list[struct[2]].
pair_exprs = [
    pl.struct(
        f1=pl.col.value_1.struct.field(bound_1),
        f2=pl.col.value_2.struct.field(bound_2),
    )
    for bound_1 in bound_fields
    for bound_2 in bound_fields
]
df.with_columns(result=pl.concat_list(pair_exprs))
shape: (2, 4)
┌────────┬──────────────────┬───────────────┬─────────────────────────────────┐
│ cases ┆ value_1 ┆ value_2 ┆ result │
│ --- ┆ --- ┆ --- ┆ --- │
│ str ┆ struct[3] ┆ struct[3] ┆ list[struct[2]] │
╞════════╪══════════════════╪═══════════════╪═════════════════════════════════╡
│ case_1 ┆ {22.5,22.0,23.0} ┆ {5.0,5.0,5.0} ┆ [{22.0,5.0}, {22.0,5.0}, … {23… │
│ case_2 ┆ {23.0,22.5,23.5} ┆ {5.0,6.0,5.0} ┆ [{22.5,6.0}, {22.5,5.0}, … {23… │
└────────┴──────────────────┴───────────────┴─────────────────────────────────┘
Or even:
fields = ["value_decimal_lb", "value_decimal_ub"]

# Long-format tables of (case, bound index, bound value) for each column. The
# index lets us restore a deterministic pair order after the join, whose row
# order is not guaranteed.
df1 = pl.concat(
    df.select("cases", i1=pl.lit(i), v1=pl.col.value_1.struct.field(x))
    for i, x in enumerate(fields)
)
df2 = pl.concat(
    df.select("cases", i2=pl.lit(i), v2=pl.col.value_2.struct.field(x))
    for i, x in enumerate(fields)
)

# Joining on `cases` yields every (v1, v2) pair per case. Assumes `cases`
# values are unique in `df` — TODO confirm; duplicates would mix rows.
df_result = (
    df1.join(df2, on="cases")
    .group_by("cases")
    .agg(result=pl.concat_list("v1", "v2").sort_by("i1", "i2"))
)

# A left join keeps `df`'s rows; `pl.concat(how="align")` would full-join
# and sort by `cases` instead.
# NOTE(review): on polars >= 1.18 consider maintain_order="left" to make the
# row-order guarantee explicit.
df.join(df_result, on="cases", how="left")
shape: (2, 4)
┌────────┬──────────────────┬───────────────┬─────────────────────────────────┐
│ cases ┆ value_1 ┆ value_2 ┆ result │
│ --- ┆ --- ┆ --- ┆ --- │
│ str ┆ struct[3] ┆ struct[3] ┆ list[list[f64]] │
╞════════╪══════════════════╪═══════════════╪═════════════════════════════════╡
│ case_1 ┆ {22.5,22.0,23.0} ┆ {5.0,5.0,5.0} ┆ [[22.0, 5.0], [23.0, 5.0], … [… │
│ case_2 ┆ {23.0,22.5,23.5} ┆ {5.0,6.0,5.0} ┆ [[22.5, 6.0], [23.5, 6.0], … [… │
└────────┴──────────────────┴───────────────┴─────────────────────────────────┘