
How to set masked values within each group in group_by context using py-polars


Update: Setting the mask in a group_by context now works as expected.


Because rank does not handle null values, I want to write a rank function that does.

import numpy as np
import polars as pl

df = pl.DataFrame({
    'group': ['a'] * 3 + ['b'] * 3,
    'value': [2, 1, None, 4, 5, 6],
})
df
shape: (6, 2)
┌───────┬───────┐
│ group ┆ value │
│ ---   ┆ ---   │
│ str   ┆ i64   │
╞═══════╪═══════╡
│ a     ┆ 2     │
│ a     ┆ 1     │
│ a     ┆ null  │
│ b     ┆ 4     │
│ b     ┆ 5     │
│ b     ┆ 6     │
└───────┴───────┘

This works well when I don't use group_by, since I can use when-then-otherwise to set values.

def valid_rank(expr: pl.Expr, descending=False):
    """Rank values, leaving nulls as null."""
    FLOAT_MAX, FLOAT_MIN = np.finfo(float).max, np.finfo(float).min
    mask = expr.is_null()
    # Fill nulls with an extreme value so they rank last, then blank them out again.
    expr = expr.fill_null(FLOAT_MIN) if descending else expr.fill_null(FLOAT_MAX)
    return pl.when(~mask).then(expr.rank(descending=descending)).otherwise(None)

df.with_columns(valid_rank(pl.col('value')))
shape: (6, 2)
┌───────┬───────┐
│ group ┆ value │
│ ---   ┆ ---   │
│ str   ┆ f32   │
╞═══════╪═══════╡
│ a     ┆ 2.0   │
│ a     ┆ 1.0   │
│ a     ┆ null  │
│ b     ┆ 3.0   │
│ b     ┆ 4.0   │
│ b     ┆ 5.0   │
└───────┴───────┘

However, in a group_by context, the predicate col("value").is_not_null() in when->then->otherwise is not an aggregation, so I get:

ComputeError: the predicate 'not(col("value").is_null())' in 'when->then->otherwise' is not a valid aggregation and might produce a different number of rows than the group_by operation would

I usually have to do further calculations within each group after ranking, and I am worried about performance if I use partition_by to split the DataFrame. So I hope Polars can offer an expression like np.putmask, or a similar function, that sets values within each group.

def valid_rank(expr: pl.Expr, descending=False):
    """handle null values when rank"""
    FLOAT_MAX, FLOAT_MIN = np.finfo(float).max, np.finfo(float).min
    mask = expr.is_null()
    expr = expr.fill_null(FLOAT_MIN) if descending else expr.fill_null(FLOAT_MAX)
    # return pl.putmask(expr.rank(descending=descending), mask, None)  # hope
    # return expr.rank(descending).set(mask, None)  # hope
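To make the desired result concrete, here is a plain-Python sketch of the behaviour being asked for: rank within each group, skip nulls, and leave them as null in the output. The helper name `grouped_valid_rank` is hypothetical, it uses ordinal ranks (ties are broken by row order), and it says nothing about how Polars would implement this.

```python
from collections import defaultdict

def grouped_valid_rank(groups, values):
    """Ordinal rank of non-null values within each group; nulls stay None."""
    rows_by_group = defaultdict(list)
    for row, (group, value) in enumerate(zip(groups, values)):
        if value is not None:  # null rows never take part in the ranking
            rows_by_group[group].append(row)

    ranks = [None] * len(values)
    for rows in rows_by_group.values():
        # Order the row indices of this group by their value, smallest first.
        ordered = sorted(rows, key=lambda row: values[row])
        for rank, row in enumerate(ordered, start=1):
            ranks[row] = float(rank)
    return ranks

groups = ["a"] * 3 + ["b"] * 3
values = [2, 1, None, 4, 5, 6]
print(grouped_valid_rank(groups, values))
# [2.0, 1.0, None, 1.0, 2.0, 3.0]
```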

Solution

  • I propose a solution that is minimally invasive to existing code, requires no changes to the Polars API, and allows masking for a wide variety of Expressions.

    Decorator: Maskable

    The decorator below is one easy way to add masking capabilities to any suitable Expression. The decorator adds two keyword-only parameters to any Expression: mask and mask_fill.

    If mask=None (the default), the decorator passes all parameters to the decorated Expression unaltered. There are no changes needed to existing code for this.

    If a mask is provided, then the decorator handles the task of masking, filtering, recombining, and sorting.

    Here's the documentation and code for the decorator. The documentation is simply from my docstring of the function. (It helps me track what I'm doing if I keep the docstring with the function as I write code.)

    (I suggest skipping directly to the Examples section first, then coming back to look at the code and documentation.)

    Overview

    from functools import wraps
    
    import polars as pl
    
    def maskable(expr: pl.Expr) -> pl.Expr:
        """
        Allow masking of values in an Expression
    
        This function is intended to be used as a decorator for Polars Expressions.
        For example:
            pl.Expr.rolling_mean = maskable(pl.Expr.rolling_mean)
    
        The intended purpose of this decorator is to change the way that an Expression
        handles exceptional values (e.g., None, NaN, Inf, -Inf, zero, negative values, etc.)
    
        Usage Notes:
        This decorator should only be applied to Expressions whose return value is the
        same length as its input (e.g., rank, rolling_mean, ewm_mean, pct_change).
        It is not intended for aggregations (e.g., sum, var, count).  (For aggregations,
        use "filter" before the aggregation Expression.)
    
        Performance Notes:
        This decorator adds significant overhead to a function call when a mask is supplied.
        As such, this decorator should not be used in places where other methods would
        suffice (e.g., filter, when/then/otherwise, fill_null, etc.)
    
        In cases where no mask is supplied, the overhead of this decorator is insignificant.
    
        Operation
        ---------
        A mask is (conceptually) a column/expression/list of boolean values that control
        which values will not be passed to the wrapped expression:
    
                True, Null -> corresponding value will not be passed to the wrapped
                expression, and will instead be filled by the mask_fill value after
                the wrapped expression has been evaluated.
    
                False -> corresponding value will be passed to the wrapped expression.
    """
    

    Parameters

    """
        Parameters
        ----------
        The decorator will add two keyword-only parameters to any wrapped Expression:
    
        mask
    
            In-Stream Masks
            ---------------
            In-stream masks select a mask based on the current state of a chained expression
            at the point where the decorated expression is called.  (See examples below)
    
            str -> One of {"Null", "NaN", "-Inf", "+Inf"}
    
            list[str] -> two or more of the above, all of which will be filled with the same
                        mask_fill value
    
            Static Masks
            ------------
            Static masks select a mask at the time the context is created, and do not reflect
            changes in values as a chained set of expressions is evaluated (see examples below)
    
            list[bool] -> external list of boolean values to use as mask
    
            pl.Series -> external Series to use as mask
    
            pl.Expr -> ad-hoc expression that evaluates to boolean
    
            Note: for static masks, it is the responsibility of the caller to ensure that the
            mask is the same length as the number of values to which it applies.
    
            No Mask
            -------
            None -> no masking applied.  The decorator passes all parameters and values to the
                    wrapped expression unaltered.  There is no significant performance penalty.
    
        mask_fill
            Fill value to be used for all values that are masked.
    
    
    """
    

    The Decorator Code

    Here is the code for the decorator itself.

    from functools import wraps
    
    import polars as pl
    
    def maskable(expr: pl.Expr) -> pl.Expr:
        @wraps(expr)
        def maskable_expr(
            self: pl.Expr,
            *args,
            mask: str | list[str] | list[bool] | pl.Series | pl.Expr | None = None,
            mask_fill: float | int | str | bool | None = None,
            **kwargs,
        ):
    
            if mask is None:
                return expr(self, *args, **kwargs)
    
            if isinstance(mask, str):
                mask = [mask]
    
            if isinstance(mask, list):
                if len(mask) == 0:
                    return expr(self, *args, **kwargs)
                if isinstance(mask[0], bool):
                    mask = pl.Series(mask)
                elif isinstance(mask[0], str):
                    mask_dict = {
                        "Null": (self.is_null()),
                        "NaN": (self.is_not_null() & self.is_nan()),
                        "+Inf": (self.is_not_null() & self.is_infinite() & (self > 0)),
                        "-Inf": (self.is_not_null() & self.is_infinite() & (self < 0)),
                    }
    
                    mask_str, *mask_list = mask
                    mask = mask_dict[mask_str]
                    while mask_list:
                        mask_str, *mask_list = mask_list
                        mask = mask | mask_dict[mask_str]
    
            if isinstance(mask, pl.Series):
                mask = pl.lit(mask)
    
            mask = mask.fill_null(True)
    
            return (
                expr(self.filter(mask.not_()), *args, **kwargs)
                .append(pl.repeat(mask_fill, mask.sum()))
                .sort_by(mask.arg_sort())
            )
    
        return maskable_expr
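As an aside, the way the string masks are combined (each name maps to a test, and the tests are OR-ed together) can be sketched element by element in plain Python. This is only an illustration of the logic, not the Polars code path; `MASK_TESTS` and `build_mask` are hypothetical names.

```python
import math

# Element-wise predicates mirroring the four mask names in the decorator.
MASK_TESTS = {
    "Null": lambda v: v is None,
    "NaN": lambda v: v is not None and math.isnan(v),
    "+Inf": lambda v: v is not None and math.isinf(v) and v > 0,
    "-Inf": lambda v: v is not None and math.isinf(v) and v < 0,
}

def build_mask(values, names):
    """Return True wherever a value matches any of the requested names."""
    if isinstance(names, str):  # a single name is treated as a one-item list
        names = [names]
    tests = [MASK_TESTS[name] for name in names]
    return [any(test(v) for test in tests) for v in values]

values = [1.0, None, float("nan"), float("inf"), float("-inf")]
print(build_mask(values, "Null"))            # [False, True, False, False, False]
print(build_mask(values, ["NaN", "Null"]))   # [False, True, True, False, False]
print(build_mask(values, ["+Inf", "-Inf"]))  # [False, False, False, True, True]
```

Every position that ends up True receives the same mask_fill value, which is why a list of names cannot assign different fills to different kinds of exceptional value.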
    

    Examples

    The following are examples of usage from the docstring that resides in my library for this decorator function. (It helps me track which use cases I've tested.)

    Simple in-stream mask

    Here's an example of a simple "in-stream" mask, based on your Stack Overflow question. The mask prevents null values from disturbing the ranking. The mask is calculated at the time that the wrapped Expression (rank) receives the data.

    Note that the changes to the code are not terribly invasive. There's no new expression, no new evaluation context required, and no changes to the Polars API. All work is done by the decorator.

    Also, note that there's no when/then/otherwise needed to achieve this; thus, the over grouping expression does not complain.

    import polars as pl
    
    pl.Expr.rank = maskable(pl.Expr.rank)
    
    df = pl.DataFrame(
        {
            "group": ["a"] * 4 + ["b"] * 4,
            "a": [1, 2, None, 3, None, 1, None, 2],
        }
    )
    
    (
        df.with_columns(
            pl.col("a")
                .rank()
                .over("group")
                .alias("rank_a"),
            pl.col("a")
                .rank(mask='Null', mask_fill=float("NaN"))
                .over("group")
                .alias("rank_a_masked"),
        )
    )
    
    shape: (8, 4)
    ┌───────┬──────┬────────┬───────────────┐
    │ group ┆ a    ┆ rank_a ┆ rank_a_masked │
    │ ---   ┆ ---  ┆ ---    ┆ ---           │
    │ str   ┆ i64  ┆ f32    ┆ f64           │
    ╞═══════╪══════╪════════╪═══════════════╡
    │ a     ┆ 1    ┆ 2.0    ┆ 1.0           │
    │ a     ┆ 2    ┆ 3.0    ┆ 2.0           │
    │ a     ┆ null ┆ 1.0    ┆ NaN           │
    │ a     ┆ 3    ┆ 4.0    ┆ 3.0           │
    │ b     ┆ null ┆ 1.5    ┆ NaN           │
    │ b     ┆ 1    ┆ 3.0    ┆ 1.0           │
    │ b     ┆ null ┆ 1.5    ┆ NaN           │
    │ b     ┆ 2    ┆ 4.0    ┆ 2.0           │
    └───────┴──────┴────────┴───────────────┘
    

    Multiple Masked values

    This is an example of a convenience built-in: multiple exceptional values can be provided in a list. Note that masked values all receive the same mask_fill value.

    This example also shows the mask working in Lazy mode, one side-benefit of using a decorator approach.

    import polars as pl
    
    pl.Expr.rolling_mean = maskable(pl.Expr.rolling_mean)
    
    df = pl.DataFrame(
        {
            "a": [1.0, 2, 3, float("NaN"), 4, None, float("NaN"), 5],
        }
    ).lazy()
    
    (
        df.with_columns(
            pl.col("a")
                .rolling_mean(window_size=2).alias("roll_mean"),
            pl.col("a")
                .rolling_mean(window_size=2, mask=['NaN', 'Null'], mask_fill=None)
                .alias("roll_mean_masked"),
        ).collect()
    )
    
    shape: (8, 3)
    ┌──────┬───────────┬──────────────────┐
    │ a    ┆ roll_mean ┆ roll_mean_masked │
    │ ---  ┆ ---       ┆ ---              │
    │ f64  ┆ f64       ┆ f64              │
    ╞══════╪═══════════╪══════════════════╡
    │ 1.0  ┆ null      ┆ null             │
    │ 2.0  ┆ 1.5       ┆ 1.5              │
    │ 3.0  ┆ 2.5       ┆ 2.5              │
    │ NaN  ┆ NaN       ┆ null             │
    │ 4.0  ┆ NaN       ┆ 3.5              │
    │ null ┆ null      ┆ null             │
    │ NaN  ┆ null      ┆ null             │
    │ 5.0  ┆ NaN       ┆ 4.5              │
    └──────┴───────────┴──────────────────┘
    

    In-stream versus Static masks

    The code below provides an example of the difference between an "in-stream" mask and a "static" mask.

    An in-stream mask makes its masking choices at the time the wrapped expression is executed. This includes the evaluated results of all chained expressions that came before it.

    By contrast, a static mask makes its masking choices when the context is created, and it never changes.

    For most use cases, in-stream masks and static masks will produce the same result. The example below is one case where they differ.

    The sqrt function creates new NaN values during the evaluation of the chained expression. The in-stream mask sees these; the static mask sees column a only as it exists at the time the with_columns context is initiated.

    import polars as pl
    
    pl.Expr.ewm_mean = maskable(pl.Expr.ewm_mean)
    
    df = pl.DataFrame(
        {
            "a": [1.0, 2, -2, 3, -4, 5, 6],
        }
    )
    
    (
        df.with_columns(
            pl.col("a").sqrt().alias('sqrt'),
            pl.col('a').sqrt()
                .ewm_mean(half_life=4, mask="NaN", mask_fill=None)
                .alias("ewm_instream"),
            pl.col("a").sqrt()
                .ewm_mean(half_life=4, mask=pl.col('a').is_nan(), mask_fill=None)
                .alias("ewm_static"),
            pl.col("a").sqrt()
                .ewm_mean(half_life=4).alias('ewm_no_mask'),
        )
    )
    
    shape: (7, 5)
    ┌──────┬──────────┬──────────────┬────────────┬─────────────┐
    │ a    ┆ sqrt     ┆ ewm_instream ┆ ewm_static ┆ ewm_no_mask │
    │ ---  ┆ ---      ┆ ---          ┆ ---        ┆ ---         │
    │ f64  ┆ f64      ┆ f64          ┆ f64        ┆ f64         │
    ╞══════╪══════════╪══════════════╪════════════╪═════════════╡
    │ 1.0  ┆ 1.0      ┆ 1.0          ┆ 1.0        ┆ 1.0         │
    │ 2.0  ┆ 1.414214 ┆ 1.225006     ┆ 1.225006   ┆ 1.225006    │
    │ -2.0 ┆ NaN      ┆ null         ┆ NaN        ┆ NaN         │
    │ 3.0  ┆ 1.732051 ┆ 1.424003     ┆ NaN        ┆ NaN         │
    │ -4.0 ┆ NaN      ┆ null         ┆ NaN        ┆ NaN         │
    │ 5.0  ┆ 2.236068 ┆ 1.682408     ┆ NaN        ┆ NaN         │
    │ 6.0  ┆ 2.4494   ┆ 1.892994     ┆ NaN        ┆ NaN         │
    └──────┴──────────┴──────────────┴────────────┴─────────────┘
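The same distinction can be seen without Polars. The plain-Python sketch below computes both masks for the column used above; `safe_sqrt` is a hypothetical helper that returns NaN for negative input, mimicking the behaviour described for sqrt, since `math.sqrt` would raise instead.

```python
import math

def safe_sqrt(v):
    """Square root that yields NaN for negative input instead of raising."""
    return math.sqrt(v) if v >= 0 else float("nan")

a = [1.0, 2.0, -2.0, 3.0, -4.0, 5.0, 6.0]
roots = [safe_sqrt(v) for v in a]

# Static mask: computed from the original column, before sqrt runs.
static_mask = [math.isnan(v) for v in a]
# In-stream mask: computed from the values the wrapped step actually receives.
instream_mask = [math.isnan(v) for v in roots]

print(static_mask)    # [False, False, False, False, False, False, False]
print(instream_mask)  # [False, False, True, False, True, False, False]
```

The static mask is all False because column a contains no NaN values, so nothing is masked and the NaNs created by sqrt propagate, as in the ewm_static column.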
    

    Incorporating external masks

    Sometimes we want to mask values based on the results of external inputs, for example in code testing, sensitivity testing, or incorporating results from external libraries/functions. External lists are, by definition, static masks. And it is up to the user to make sure that they are the correct length to match the column that they are masking.

    The example below also demonstrates that the scope of a mask (in-stream or static) is limited to one expression evaluation. The mask does not stay in effect for other expressions in a chained expression. (However, you can certainly declare masks for other expressions in a single chain.) In the example below, diff does not see the mask that was used for the prior rank step.

    import polars as pl
    
    pl.Expr.rank = maskable(pl.Expr.rank)
    pl.Expr.diff = maskable(pl.Expr.diff)
    
    df = pl.DataFrame(
        {
            "trial_nbr": [1, 2, 3, 4, 5, 6],
            "response": [1.0, -5, 9, 3, 2, 10],
        }
    )
    
    pending = [False, True, False, False, False, False]
    (
        df.with_columns(
            pl.col("response").rank().alias('rank'),
            pl.col("response")
                .rank(mask=pending, mask_fill=float("NaN"))
                .alias('rank_masked'),
            pl.col("response")
                .rank(mask=pending, mask_fill=float("NaN"))
                .diff()
                .alias('diff_rank'),
        )
    )
    
    shape: (6, 5)
    ┌───────────┬──────────┬──────┬─────────────┬───────────┐
    │ trial_nbr ┆ response ┆ rank ┆ rank_masked ┆ diff_rank │
    │ ---       ┆ ---      ┆ ---  ┆ ---         ┆ ---       │
    │ i64       ┆ f64      ┆ f32  ┆ f64         ┆ f64       │
    ╞═══════════╪══════════╪══════╪═════════════╪═══════════╡
    │ 1         ┆ 1.0      ┆ 2.0  ┆ 1.0         ┆ null      │
    │ 2         ┆ -5.0     ┆ 1.0  ┆ NaN         ┆ NaN       │
    │ 3         ┆ 9.0      ┆ 5.0  ┆ 4.0         ┆ NaN       │
    │ 4         ┆ 3.0      ┆ 4.0  ┆ 3.0         ┆ -1.0      │
    │ 5         ┆ 2.0      ┆ 3.0  ┆ 2.0         ┆ -1.0      │
    │ 6         ┆ 10.0     ┆ 6.0  ┆ 5.0         ┆ 3.0       │
    └───────────┴──────────┴──────┴─────────────┴───────────┘
    

    map_elements

    This approach also works with map_elements (but currently only when map_elements is used with only one column input, not when a struct is used to pass multiple values to map_elements).

    For example, the simple function below will throw an exception if a value greater than 1.0 is passed to my_func. Normally, this would halt execution, and some kind of workaround would be needed, such as setting the value to something else and remembering to set its value back after map_elements is run. Using a mask, you can side-step the problem conveniently, without such a workaround.

    import polars as pl
    import math
    
    pl.Expr.map_elements = maskable(pl.Expr.map_elements)
    
    def my_func(value: float) -> float:
        return math.acos(value)
    
    df = pl.DataFrame(
        {
            "val": [0.0, 0.5, 0.7, 0.9, 1.0, 1.1],
        }
    )
    
    (
        df.with_columns(
            pl.col('val')
            .map_elements(function=my_func,
                   mask=pl.col('val') > 1.0,
                   mask_fill=float('NaN')
                   )
            .alias('result')
        )
    )
    
    shape: (6, 2)
    ┌─────┬──────────┐
    │ val ┆ result   │
    │ --- ┆ ---      │
    │ f64 ┆ f64      │
    ╞═════╪══════════╡
    │ 0.0 ┆ 1.570796 │
    │ 0.5 ┆ 1.047198 │
    │ 0.7 ┆ 0.795399 │
    │ 0.9 ┆ 0.451027 │
    │ 1.0 ┆ 0.0      │
    │ 1.1 ┆ NaN      │
    └─────┴──────────┘
    

    The Algorithm

    The heart of the algorithm is these few lines:

    expr(self.filter(mask.not_()), *args, **kwargs)
    .append(pl.repeat(mask_fill, mask.sum()))
    .sort_by(mask.arg_sort())
    

    In steps,

    • The algorithm filters the results of the current state of the chained expression based on the mask, and passes the filtered results to the wrapped expression for evaluation.
    • The column of returned values from the evaluated expression is then extended to its former length by filling with the mask_fill values.
    • An argsort on the mask is then used to restore the filled values at the bottom to their place among the returned values.

    This last step assumes that the filter step maintains the relative ordering of rows (which it does), and that the mask_fill values are indistinguishable/identical (which they are).
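The three steps can be mirrored with plain Python lists, which may make the reordering easier to follow. This is a sketch of the idea only: `masked_apply` and `ordinal_rank` are hypothetical names, and Python's stable `sorted` stands in for the arg_sort on the mask.

```python
def masked_apply(func, values, mask, mask_fill=None):
    """Apply func to the unmasked values, then restore the original row order."""
    mask = [True if m is None else m for m in mask]  # null mask entries count as masked

    # 1. Filter: only unmasked values reach the wrapped function.
    kept = [v for v, m in zip(values, mask) if not m]
    result = list(func(kept))

    # 2. Append: pad back to the original length with the fill value.
    result.extend([mask_fill] * sum(mask))

    # 3. Sort back: a stable argsort of the mask lists the unmasked rows first,
    #    then the masked rows, which is exactly the order `result` is in.
    order = sorted(range(len(mask)), key=lambda i: mask[i])
    restored = [None] * len(mask)
    for item, row in zip(result, order):
        restored[row] = item
    return restored

def ordinal_rank(xs):
    """Hypothetical stand-in for the wrapped expression."""
    by_value = sorted(range(len(xs)), key=lambda i: xs[i])
    position = {i: r for r, i in enumerate(by_value, start=1)}
    return [position[i] for i in range(len(xs))]

print(masked_apply(ordinal_rank, [30, None, 10, 20], [False, True, False, False]))
# [3, None, 1, 2]
```

Because the filter keeps the surviving rows in their original relative order, pairing `result` with `order` position by position is enough to send every value back to its row.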

    Benefits and Limitations

    Using this approach has some notable benefits:

    • The impact to code is minimal. No complex workarounds are needed (e.g., partitioning DataFrames, changing values)
    • There is zero impact to the Polars API. No new expressions. No new context. No new keywords.
    • Decorated Expressions continue to run in parallel. The Python code in the decorator merely writes expressions and passes them along; the Python code itself does not run calculations on data.
    • Decorated Expressions retain their familiar names and signatures, with the exception of two additional keyword-only parameters, which default to no-masking.
    • Decorated Expressions work in both Lazy and Eager mode.
    • Decorated Expressions can be used just like any other Expression, including chaining Expressions and using over for grouping.
    • The performance impact when a decorated Expression is used without masking is insignificant. The decorator merely passes the parameters to the wrapped Expression unaltered.

    Some limitations do apply:

    • The type hints (as they are stated above) may raise errors with linters and IDEs when using decorated Expressions. Some linters will complain that mask and mask_fill are not valid parameters.
    • Not all Expressions are suitable for masking. Masking will not work for aggregation expressions, in particular. (Nor should it; simple filtering before an aggregating expression will be far faster than masking.)

    Performance Impact

    Using a mask with an Expression will impact performance. The additional runtime is associated with filtering based on the mask and then sorting to place the mask_fill values back to their proper place in the results. This last step requires sorting, which is O(n log n), in general.

    The performance overhead is more or less independent of the expression that is wrapped by the decorator. Instead, the performance impact is a function of the number of records involved, due to the filtering and the sorting steps.

    Whether the performance impact outweighs the convenience of this approach is probably better discussed on GitHub (depending on whether this approach is acceptable).

    And there may be ways to reduce the O(n log n) complexity at the heart of the algorithm, if the performance impact is deemed too severe. I tried an approach that interleaves the results returned from the wrapped function with the fill values, based on the mask, but it performed no better than the simple sort that is shown above. Perhaps there is a way to interleave the two in a more performant manner.
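For reference, the interleaving idea is linear in the number of rows when written as a single pass over the mask. The plain-Python sketch below shows only the logic; it is not the implementation that was benchmarked, `interleave_fill` is a hypothetical name, and it makes no claim about how fast an equivalent Polars expression would be.

```python
def interleave_fill(results, mask, mask_fill=None):
    """Merge wrapped-function results with fill values in one pass over the mask."""
    remaining = iter(results)
    # Each unmasked row consumes the next result; each masked row gets the fill value.
    return [mask_fill if masked else next(remaining) for masked in mask]

print(interleave_fill([3, 1, 2], [False, True, False, False]))
# [3, None, 1, 2]
```

This relies on the same assumption as the sort-based version: the results must be in the original relative order of the unmasked rows.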

    I would point out one thing, though. Masking will come with a performance cost (no matter what approach is used). Thus, comparing 'no-masking' to 'masking' may not be terribly informative. Instead, 'masking' accomplished with one algorithm versus another is probably the better comparison.