python-polars

How to transform a series or data frame column with continuous data into a categorical one based on a set of breaks in polars?


A standard way to achieve this in R is by using cut(), and in pandas by using pd.cut(). This is a straightforward way to convert numerical data into categories that can be defined on the fly (the category boundaries are usually referred to as breaks).

To avoid reinventing the wheel, I will refer to the pandas documentation and R for examples, since they are much better than any I could improvise here.


Solution

  • Update: Polars now has cut and qcut functions; a quick sketch of the built-in API is shown below.
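
    As a minimal sketch of the built-in API, assuming a recent Polars version (the cut and qcut signatures have shifted between releases):

    import polars as pl

    s = pl.Series("var1", [-1.5, -0.5, 0.5, 1.5])

    # Three break points produce four bins, so four labels are needed.
    s.cut([-1.0, 0.0, 1.0], labels=["lo", "mid_lo", "mid_hi", "hi"])

    # qcut bins by quantiles instead of fixed break points.
    s.qcut([0.25, 0.5, 0.75], labels=["q1", "q2", "q3", "q4"])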


    Original answer:

    Polars does not have a cut function, per se. However, one simple (and very performant) way to cut data is to use join_asof.

    Let's start with this data.

    import polars as pl
    import numpy as np
    
    sample_size = 10
    df = pl.DataFrame(
        {
            "var1": np.random.default_rng(seed=0).normal(0, 1, sample_size),
        }
    )
    df
    
    shape: (10, 1)
    ┌───────────┐
    │ var1      │
    │ ---       │
    │ f64       │
    ╞═══════════╡
    │ 0.12573   │
    │ -0.132105 │
    │ 0.640423  │
    │ 0.1049    │
    │ -0.535669 │
    │ 0.361595  │
    │ 1.304     │
    │ 0.947081  │
    │ -0.703735 │
    │ -1.265421 │
    └───────────┘
    

    The Algorithm

    Step 1: Create a simple dataset with break points and cat variables

    First, we'll create a DataFrame containing our break points, along with our categorical variable. Let's say we want break points of -1, 0, and 1. We'll supply these in the constructor.

    I'll also use with_row_index to automatically generate our categorical values. (If you like, you can assign something else for a categorical variable in the constructor.)

    Note that the data type of your break points must match the data type of the variable that you are cutting. Hence, in this simple example, I'm writing the first break point as -1.0, so that Polars infers break_pt as Float64 rather than as an integer.

    break_df = pl.DataFrame(
        {
            "break_pt": [-1.0, 0, 1],
        }
    ).with_row_index("binned")
    break_df
    
    shape: (3, 2)
    ┌────────┬──────────┐
    │ binned ┆ break_pt │
    │ ---    ┆ ---      │
    │ u32    ┆ f64      │
    ╞════════╪══════════╡
    │ 0      ┆ -1.0     │
    │ 1      ┆ 0.0      │
    │ 2      ┆ 1.0      │
    └────────┴──────────┘
    
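    As an aside, if your break points start out as integers, an explicit cast keeps the asof-join keys' dtypes aligned; a minimal sketch:

    break_df = break_df.with_columns(pl.col("break_pt").cast(pl.Float64))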

    Step 2: Join using join_asof

    Now we can perform the join_asof. Note that both datasets must be sorted by the asof keys, so we'll need to sort our DataFrame of random numbers by our continuous variable (var1) before the join. (break_df is already sorted.)

    (
        df
        .sort("var1")
        .join_asof(
            break_df,
            left_on="var1",
            right_on="break_pt",
            strategy="forward",
        )
    )
    
    shape: (10, 3)
    ┌───────────┬────────┬──────────┐
    │ var1      ┆ binned ┆ break_pt │
    │ ---       ┆ ---    ┆ ---      │
    │ f64       ┆ u32    ┆ f64      │
    ╞═══════════╪════════╪══════════╡
    │ -1.265421 ┆ 0      ┆ -1.0     │
    │ -0.703735 ┆ 1      ┆ 0.0      │
    │ -0.535669 ┆ 1      ┆ 0.0      │
    │ -0.132105 ┆ 1      ┆ 0.0      │
    │ 0.1049    ┆ 2      ┆ 1.0      │
    │ 0.12573   ┆ 2      ┆ 1.0      │
    │ 0.361595  ┆ 2      ┆ 1.0      │
    │ 0.640423  ┆ 2      ┆ 1.0      │
    │ 0.947081  ┆ 2      ┆ 1.0      │
    │ 1.304     ┆ null   ┆ null     │
    └───────────┴────────┴──────────┘
    

    This leaves the last bin (the values above the last break point) as null. To fill these null values with a proper bin number, we can use fill_null.

    (
        df
        .sort("var1")
        .join_asof(
            break_df,
            left_on="var1",
            right_on="break_pt",
            strategy="forward",
        )
        .with_columns(pl.col("binned").fill_null(pl.col("binned").max() + 1))
    )
    
    shape: (10, 3)
    ┌───────────┬────────┬──────────┐
    │ var1      ┆ binned ┆ break_pt │
    │ ---       ┆ ---    ┆ ---      │
    │ f64       ┆ u32    ┆ f64      │
    ╞═══════════╪════════╪══════════╡
    │ -1.265421 ┆ 0      ┆ -1.0     │
    │ -0.703735 ┆ 1      ┆ 0.0      │
    │ -0.535669 ┆ 1      ┆ 0.0      │
    │ -0.132105 ┆ 1      ┆ 0.0      │
    │ 0.1049    ┆ 2      ┆ 1.0      │
    │ 0.12573   ┆ 2      ┆ 1.0      │
    │ 0.361595  ┆ 2      ┆ 1.0      │
    │ 0.640423  ┆ 2      ┆ 1.0      │
    │ 0.947081  ┆ 2      ┆ 1.0      │
    │ 1.304     ┆ 3      ┆ null     │
    └───────────┴────────┴──────────┘
    
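    One small dtype note: adding the integer literal 1 inside fill_null promotes binned from u32 (you'll see i64 in the larger run below). If the original dtype matters, a cast after the fill brings it back; a minimal sketch, where df_binned stands for the joined result above:

    df_binned = df_binned.with_columns(
        pl.col("binned")
        .fill_null(pl.col("binned").max() + 1)
        .cast(pl.UInt32)  # restore the row-index dtype
    )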

    Performance

    So how well does this perform? Let's increase our random sample to 100 million values, and let's expand our break list to ~6,000 break points.

    sample_size = 100_000_000
    df = pl.DataFrame(
        {
            "var1": np.random.default_rng(seed=0).normal(0, 1, sample_size),
        }
    )
    df
    
    shape: (100000000, 1)
    ┌───────────┐
    │ var1      │
    │ ---       │
    │ f64       │
    ╞═══════════╡
    │ 0.1257    │
    │ -0.132105 │
    │ 0.640423  │
    │ 0.1049    │
    │ …         │
    │ -0.714924 │
    │ 0.269947  │
    │ -2.3158   │
    │ -0.383743 │
    └───────────┘
    
    break_list = [next_val / 1000 for next_val in range(-3000, 3001)]
    
    break_df = pl.DataFrame(
        {
            "break_pt": break_list,
        }
    ).with_row_index("binned")
    break_df
    
    shape: (6001, 2)
    ┌────────┬──────────┐
    │ binned ┆ break_pt │
    │ ---    ┆ ---      │
    │ u32    ┆ f64      │
    ╞════════╪══════════╡
    │ 0      ┆ -3.0     │
    │ 1      ┆ -2.999   │
    │ 2      ┆ -2.998   │
    │ 3      ┆ -2.997   │
    │ …      ┆ …        │
    │ 5997   ┆ 2.997    │
    │ 5998   ┆ 2.998    │
    │ 5999   ┆ 2.999    │
    │ 6000   ┆ 3.0      │
    └────────┴──────────┘
    

    And now timing the algorithm itself...

    import time
    start = time.perf_counter()
    (
        df.sort("var1")
        .join_asof(
            break_df,
            left_on="var1",
            right_on="break_pt",
            strategy="forward",
        )
        .with_columns(pl.col("binned").fill_null(pl.col("binned").max() + 1))
    )
    print(time.perf_counter() - start)
    
    shape: (100000000, 3)
    ┌───────────┬────────┬──────────┐
    │ var1      ┆ binned ┆ break_pt │
    │ ---       ┆ ---    ┆ ---      │
    │ f64       ┆ i64    ┆ f64      │
    ╞═══════════╪════════╪══════════╡
    │ -5.666706 ┆ 0      ┆ -3.0     │
    │ -5.6048   ┆ 0      ┆ -3.0     │
    │ -5.428571 ┆ 0      ┆ -3.0     │
    │ -5.350106 ┆ 0      ┆ -3.0     │
    │ …         ┆ …      ┆ …        │
    │ 5.327897  ┆ 6001   ┆ null     │
    │ 5.344677  ┆ 6001   ┆ null     │
    │ 5.386379  ┆ 6001   ┆ null     │
    │ 5.4829    ┆ 6001   ┆ null     │
    └───────────┴────────┴──────────┘
    3.107058142999449
    

    Just over 3 seconds.
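
    Timings will of course vary with hardware. If you prefer deferred execution, the same pipeline also works on LazyFrames; a minimal sketch using the df and break_df from above:

    result = (
        df.lazy()
        .sort("var1")
        .join_asof(
            break_df.lazy(),
            left_on="var1",
            right_on="break_pt",
            strategy="forward",
        )
        .with_columns(pl.col("binned").fill_null(pl.col("binned").max() + 1))
        .collect()
    )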

    Edit: Some helpful additions/improvements

    Some improvements that we can make:

    • adding the capacity for labels, either default labels or a specified list of labels
    • returning the labels as a categorical variable
    • allowing the list of cut values to be passed as integers
    • encapsulating the algorithm into a callable function (shown below)

    from typing import List, Optional

    def cut_dataframe(_df: pl.DataFrame,
                      var_nm: str,
                      bins: List[float],
                      labels: Optional[List[str]] = None) -> pl.DataFrame:

        # Append +inf to the break points so that values above the last
        # break point land in an open-ended top bin (no nulls to fill).
        cuts_df = pl.DataFrame([
            pl.Series(
                name="break_pt",
                values=bins,
                dtype=pl.Float64,
            ).extend_constant(np.inf, 1)
        ])

        if labels:
            # With n break points, n + 1 labels are expected.
            cuts_df = cuts_df.with_columns(
                pl.Series(
                    name="category",
                    values=labels,
                )
            )
        else:
            # Default labels in half-open interval notation, e.g. "(-inf, -1.0]".
            cuts_df = cuts_df.with_columns(
                pl.format(
                    "({}, {}]",
                    pl.col("break_pt").shift(1, fill_value=-np.inf),
                    pl.col("break_pt"),
                )
                .alias("category")
            )

        cuts_df = cuts_df.with_columns(pl.col("category").cast(pl.Categorical))

        # Both sides of a join_asof must be sorted on the join keys;
        # cuts_df is already sorted by construction.
        result = (
            _df.sort(var_nm).join_asof(
                cuts_df,
                left_on=var_nm,
                right_on="break_pt",
                strategy="forward",
            )
        )
        return result
    

    We can now cut our DataFrame with a single call.

    Below, we'll call the function without specifying any labels, allowing the function to create default labels. And notice that we don't have to worry about passing our break points as floats: the function automatically casts the values to pl.Float64.

    cut_dataframe(df, "var1", [-1, 1])
    
    shape: (10, 3)
    ┌───────────┬──────────┬──────────────┐
    │ var1      ┆ break_pt ┆ category     │
    │ ---       ┆ ---      ┆ ---          │
    │ f64       ┆ f64      ┆ cat          │
    ╞═══════════╪══════════╪══════════════╡
    │ -1.265421 ┆ -1.0     ┆ (-inf, -1.0] │
    │ -0.703735 ┆ 1.0      ┆ (-1.0, 1.0]  │
    │ -0.535669 ┆ 1.0      ┆ (-1.0, 1.0]  │
    │ -0.132105 ┆ 1.0      ┆ (-1.0, 1.0]  │
    │ 0.1049    ┆ 1.0      ┆ (-1.0, 1.0]  │
    │ 0.1257    ┆ 1.0      ┆ (-1.0, 1.0]  │
    │ 0.361595  ┆ 1.0      ┆ (-1.0, 1.0]  │
    │ 0.640423  ┆ 1.0      ┆ (-1.0, 1.0]  │
    │ 0.947081  ┆ 1.0      ┆ (-1.0, 1.0]  │
    │ 1.304     ┆ inf      ┆ (1.0, inf]   │
    └───────────┴──────────┴──────────────┘
    

    And here, we'll pass in a list of labels. (With two break points, three labels are needed: the extra one covers the open-ended top bin.)

    cut_dataframe(df, "var1", [-1, 1], ["low", "med", "hi"])
    
    shape: (10, 3)
    ┌───────────┬──────────┬──────────┐
    │ var1      ┆ break_pt ┆ category │
    │ ---       ┆ ---      ┆ ---      │
    │ f64       ┆ f64      ┆ cat      │
    ╞═══════════╪══════════╪══════════╡
    │ -1.265421 ┆ -1.0     ┆ low      │
    │ -0.703735 ┆ 1.0      ┆ med      │
    │ -0.535669 ┆ 1.0      ┆ med      │
    │ -0.132105 ┆ 1.0      ┆ med      │
    │ 0.1049    ┆ 1.0      ┆ med      │
    │ 0.1257    ┆ 1.0      ┆ med      │
    │ 0.361595  ┆ 1.0      ┆ med      │
    │ 0.640423  ┆ 1.0      ┆ med      │
    │ 0.947081  ┆ 1.0      ┆ med      │
    │ 1.304     ┆ inf      ┆ hi       │
    └───────────┴──────────┴──────────┘
    
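    One caveat: because join_asof requires sorting by the cut variable, results come back ordered by var1 rather than in the original row order. If the original order matters, a row index restores it; a minimal sketch (orig_idx is just a throwaway helper column):

    result = (
        cut_dataframe(df.with_row_index("orig_idx"), "var1", [-1, 1])
        .sort("orig_idx")  # restore the original row order
        .drop("orig_idx")
    )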

    Hopefully, the above is more helpful.