Tags: python, python-polars, pandera

Optimizing pandera polars validation check function


I'm testing out switching to polars from pandas, and running into performance issues I wasn't expecting. Hoping this is just an issue of not knowing the really optimized lazyframe way of validating data.

Here is one of the checks where I'm noticing a relatively significant difference between the polars and pandas versions:

from datetime import datetime

import pandera.polars as pa
import polars as pl


def has_no_conditional_field_conflict(
    grouped_data: pa.PolarsData,
    condition_values: set[str] = {"977"},
    groupby_fields: str = "",
    separator: str = ";",
) -> pl.LazyFrame:
    start = datetime.now()

    lf = grouped_data.lazyframe
    # Field A: split the semicolon-separated string into a list of values
    check_col = pl.col(groupby_fields).str.split(separator)
    # Field B: the column the check is registered on
    val_col = pl.col(grouped_data.key).str.strip()

    # Pass if Field A shares no value with condition_values and Field B is blank,
    # or Field A contains a condition value and Field B is not blank
    check_results = (
        (check_col.apply(lambda arr: set(arr).isdisjoint(condition_values), return_dtype=pl.Boolean)) &
        (val_col == "")
    ) | (
        (check_col.apply(lambda arr: not set(arr).isdisjoint(condition_values), return_dtype=pl.Boolean)) &
        (val_col != "")
    )

    rf = lf.with_columns(check_results.alias("check_results")).select("check_results").collect()
    print(f"Processing of has_no_conditional_field_conflict took {(datetime.now() - start).total_seconds()} seconds")
    return rf.lazy()

In polars, this function is taking on average ~0.1 seconds, and is used for many fields (called 41 times during a validation run). The overall validation time for 10,000 entries is taking about 8.5 seconds. If I remove the .collect() and just pass back the lazyframe with expressions, the total processing of the function itself is about 0.0007 seconds, but then the overall validation run takes about 13 seconds.

When running in pandas, using groupby and iterating over the grouped data (which pandera's pandas checks provide as a dict[value, Series]) over the same data set, I see check function times of 0.008 seconds and an overall validation of 6 seconds.

Is there a more optimized way that I can use polars in this check? I know apply is generally not favored on large dataframes, but I haven't been able to figure out a better way to achieve what the check needs. I've been able to make significant improvements in other places, but this thing seems to be my current bottleneck.

Update: The purpose of this check is to verify whether Field A's (check_col) value is within the condition_values set. If Field A is in the set, then Field B (val_col) must not be blank. If Field A is not in the set (hence the isdisjoint), then Field B must be blank. Field A can be either a single value or a semicolon-separated string of values, so for example Field A might be 900;977. condition_values defaults to {"977"} but can potentially be a set of several values.
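
To make the rule concrete, here is a plain-Python, row-level restatement. row_is_valid is a hypothetical helper used only for illustration; it is not part of the actual pandera check:

# Hypothetical helper: row-level restatement of the rule, for illustration only
def row_is_valid(field_a: str, field_b: str,
                 condition_values: set[str] = {"977"},
                 separator: str = ";") -> bool:
    a_values = set(field_a.split(separator))
    if a_values & condition_values:      # Field A contains a condition value
        return field_b.strip() != ""     # -> Field B must not be blank
    return field_b.strip() == ""         # -> otherwise Field B must be blank

assert row_is_valid("900;977", "foo")    # condition value present, Field B filled
assert row_is_valid("123", "")           # no condition value, Field B blank
assert not row_is_valid("977", "")       # condition value present but Field B blank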


Solution

  • Polars does not yet have an explicit isdisjoint() function.

    - https://github.com/pola-rs/polars/issues/9908

    As a workaround, you can check the length of the list.set_intersection result.

    import polars as pl

    df = pl.DataFrame({
        "check_col": ["900;977", "1;977;2", "", "123"],
        "val_col": ["foo", "", "bar", ""]
    })

    condition_values = {"977", "1"}

    df.with_columns(is_disjoint =
        pl.col.check_col.str.split(";").list.set_intersection(list(condition_values))
          .list.len() == 0
    )
    
    shape: (4, 3)
    ┌───────────┬─────────┬─────────────┐
    │ check_col ┆ val_col ┆ is_disjoint │
    │ ---       ┆ ---     ┆ ---         │
    │ str       ┆ str     ┆ bool        │
    ╞═══════════╪═════════╪═════════════╡
    │ 900;977   ┆ foo     ┆ false       │
    │ 1;977;2   ┆         ┆ false       │
    │           ┆ bar     ┆ true        │
    │ 123       ┆         ┆ true        │
    └───────────┴─────────┴─────────────┘
    
    • note: Polars does not currently accept set objects, so we explicitly call list() when passing in condition_values
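
    Putting this together with the question's check, the whole rule can be expressed with native expressions and no apply. The following is a minimal sketch, assuming grouped_data.key and groupby_fields are populated the same way as in the question, and using .str.strip_chars() (the newer name for .str.strip()); it has not been run against the asker's pandera setup:

    import pandera.polars as pa
    import polars as pl

    def has_no_conditional_field_conflict(
        grouped_data: pa.PolarsData,
        condition_values: set[str] = {"977"},
        groupby_fields: str = "",
        separator: str = ";",
    ) -> pl.LazyFrame:
        # True when Field A shares no value with condition_values
        is_disjoint = (
            pl.col(groupby_fields)
            .str.split(separator)
            .list.set_intersection(list(condition_values))
            .list.len() == 0
        )
        # True when Field B is blank (use .str.strip() on older polars versions)
        val_is_blank = pl.col(grouped_data.key).str.strip_chars() == ""

        # disjoint -> Field B must be blank; not disjoint -> Field B must be filled
        check_results = (is_disjoint & val_is_blank) | (~is_disjoint & ~val_is_blank)

        return (
            grouped_data.lazyframe
            .with_columns(check_results.alias("check_results"))
            .select("check_results")
        )

    Whether to .collect() here or hand the lazy frame back to pandera is the same trade-off described in the question; the key change is replacing the Python-level apply/isdisjoint calls with list.set_intersection.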