How to combine when-then with multi-column conditional computation in Polars


[Modified & extended from another question] My data frame has string columns A, L, and G. Each value in L and G is a letter followed by a 2-digit number. If the A string is "foo" or "spam", the G string in that row should be changed to its original letter + L's original number, and the L string should be changed to "XX".

import polars as pl

df = pl.DataFrame(
    {
        "A": ["foo", "ham", "spam", "egg",],
        "L": ["A54", "A12", "B84", "C12"],
        "G": ["X34", "C84", "G96", "L60",],
    }
)
print(df)

shape: (4, 3)
┌──────┬─────┬─────┐
│ A    ┆ L   ┆ G   │
│ ---  ┆ --- ┆ --- │
│ str  ┆ str ┆ str │
╞══════╪═════╪═════╡
│ foo  ┆ A54 ┆ X34 │
│ ham  ┆ A12 ┆ C84 │
│ spam ┆ B84 ┆ G96 │
│ egg  ┆ C12 ┆ L60 │
└──────┴─────┴─────┘

Expected result:

shape: (4, 3)
┌──────┬─────┬─────┐
│ A    ┆ L   ┆ G   │
│ ---  ┆ --- ┆ --- │
│ str  ┆ str ┆ str │
╞══════╪═════╪═════╡
│ foo  ┆ XX  ┆ X54 │
│ ham  ┆ A12 ┆ C84 │
│ spam ┆ XX  ┆ G84 │
│ egg  ┆ C12 ┆ L60 │
└──────┴─────┴─────┘

Solution

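  • Because the replacement logic differs per column, write a separate when/then/otherwise expression for each column. df.with_columns accepts multiple expressions, and all of them are evaluated against the original frame, so the expression for G still sees the original L values even though L is overwritten in the same call.

    This solution assumes that the "letter" part is always one character and the "number" part is always two characters.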

    import polars as pl
    from polars.testing import assert_frame_equal
    
    df = pl.DataFrame(
        {
            "A": ["foo", "ham", "spam", "egg",],
            "L": ["A54", "A12", "B84", "C12"],
            "G": ["X34", "C84", "G96", "L60",],
        }
    )
    
    # build pl.when(...) once and reuse it, so the condition is not repeated
    predicate = pl.when(pl.col("A").is_in(["foo", "spam"]))
    result = (
        df.with_columns(
            # L becomes "XX" on matching rows
            predicate.then(pl.lit("XX")).otherwise(pl.col("L")).alias("L"),
            # G keeps its own letter and takes the two digits of the original L
            predicate.then(
                pl.col("G").str.slice(0, 1) +
                pl.col("L").str.slice(1, 2)
            ).otherwise(pl.col("G")).alias("G"),
        )
    )
    
    # test the result
    expected_result = pl.from_repr(
        """
        ┌──────┬─────┬─────┐
        │ A    ┆ L   ┆ G   │
        │ ---  ┆ --- ┆ --- │
        │ str  ┆ str ┆ str │
        ╞══════╪═════╪═════╡
        │ foo  ┆ XX  ┆ X54 │
        │ ham  ┆ A12 ┆ C84 │
        │ spam ┆ XX  ┆ G84 │
        │ egg  ┆ C12 ┆ L60 │
        └──────┴─────┴─────┘
        """
    )
    assert_frame_equal(result, expected_result)