
How to create multiple output columns from a single when/then condition in Polars?


I am trying to create two new columns based on the same condition, but I am not sure how to do that.

Sample DataFrame:

import polars as pl

so_df = pl.DataFrame({"low_limit": [1, 3, 0], "high_limit": [3, 4, 2], "value": [0, 5, 1]})
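shape: (3, 3)
┌───────────┬────────────┬───────┐
│ low_limit ┆ high_limit ┆ value │
│ ---       ┆ ---        ┆ ---   │
│ i64       ┆ i64        ┆ i64   │
╞═══════════╪════════════╪═══════╡
│ 1         ┆ 3          ┆ 0     │
│ 3         ┆ 4          ┆ 5     │
│ 0         ┆ 2          ┆ 1     │
└───────────┴────────────┴───────┘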

Code that works for creating a single column:

so_df.with_columns(pl.when(pl.col('value') > pl.col('high_limit'))
                   .then(pl.lit("High"))
                   .when((pl.col('value') < pl.col('low_limit')))
                   .then(pl.lit("Low"))
                   .otherwise(pl.lit("Within Range")).alias('Flag')
)

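Output:

shape: (3, 4)
┌───────────┬────────────┬───────┬──────────────┐
│ low_limit ┆ high_limit ┆ value ┆ Flag         │
│ ---       ┆ ---        ┆ ---   ┆ ---          │
│ i64       ┆ i64        ┆ i64   ┆ str          │
╞═══════════╪════════════╪═══════╪══════════════╡
│ 1         ┆ 3          ┆ 0     ┆ Low          │
│ 3         ┆ 4          ┆ 5     ┆ High         │
│ 0         ┆ 2          ┆ 1     ┆ Within Range │
└───────────┴────────────┴───────┴──────────────┘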

Problem: my attempt at creating two columns this way does not work:

so_df.with_columns(pl.when(pl.col('value') > pl.col('high_limit'))
                   .then(Flag = pl.lit("High"), Normality = pl.lit("Abnormal"))
                   .when((pl.col('value') < pl.col('low_limit')))
                   .then(Flag = pl.lit("Low"), Normality = pl.lit("Abnormal"))
                   .otherwise(Flag = pl.lit("Within Range"), Normality = pl.lit("Normal"))
)

Desired output:

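shape: (3, 5)
┌───────────┬────────────┬───────┬──────────────┬───────────┐
│ low_limit ┆ high_limit ┆ value ┆ Flag         ┆ Normality │
│ ---       ┆ ---        ┆ ---   ┆ ---          ┆ ---       │
│ i64       ┆ i64        ┆ i64   ┆ str          ┆ str       │
╞═══════════╪════════════╪═══════╪══════════════╪═══════════╡
│ 1         ┆ 3          ┆ 0     ┆ Low          ┆ Abnormal  │
│ 3         ┆ 4          ┆ 5     ┆ High         ┆ Abnormal  │
│ 0         ┆ 2          ┆ 1     ┆ Within Range ┆ Normal    │
└───────────┴────────────┴───────┴──────────────┴───────────┘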

I know I could add another with_columns call with a second when/then chain, but that would evaluate the same conditions twice. How can I create both new columns in one go?

Something like this (pseudocode):

if condition:
    Flag = ''
    Normality = ''

Solution

  • You can return a pl.struct from each when/then branch and then extract both fields at once with .struct.field(...):

    df = so_df.with_columns(
        pl.when(pl.col("value") > pl.col("high_limit"))
        .then(pl.struct(Flag=pl.lit("High"), Normality=pl.lit("Abnormal")))
        .when(pl.col("value") < pl.col("low_limit"))
        .then(pl.struct(Flag=pl.lit("Low"), Normality=pl.lit("Abnormal")))
        .otherwise(pl.struct(Flag=pl.lit("Within Range"), Normality=pl.lit("Normal")))
        .struct.field("Flag", "Normality")
    )
    

    Output:

    shape: (3, 5)
    ┌───────────┬────────────┬───────┬──────────────┬───────────┐
    │ low_limit ┆ high_limit ┆ value ┆ Flag         ┆ Normality │
    │ ---       ┆ ---        ┆ ---   ┆ ---          ┆ ---       │
    │ i64       ┆ i64        ┆ i64   ┆ str          ┆ str       │
    ╞═══════════╪════════════╪═══════╪══════════════╪═══════════╡
    │ 1         ┆ 3          ┆ 0     ┆ Low          ┆ Abnormal  │
    │ 3         ┆ 4          ┆ 5     ┆ High         ┆ Abnormal  │
    │ 0         ┆ 2          ┆ 1     ┆ Within Range ┆ Normal    │
    └───────────┴────────────┴───────┴──────────────┴───────────┘