
Using Polars cut with labels, from a pandas perspective


I'm migrating some code from pandas to Polars. I'm trying to use Polars' cut, but there are differences: it has no bins argument, so I have to calculate the breakpoints myself.

But I still don't understand how labels work in Polars' cut.

I have to pass more labels than I want in order to get the same result as pandas.

import pandas as pd
import polars as pl

# Example Polars DataFrame
data = {
    "value": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
}
df_pl = pl.DataFrame(data)

# Convert to a pandas DataFrame to get the breakpoints
df_pd = df_pl.to_pandas()

# Use retbins to get the breakpoints (from pandas)
df_pd["cut_label_pd"], breakpoints = pd.cut(df_pd["value"], 4, labels=["low", "medium", "high", "very high"], retbins=True)
print(pl.from_pandas(df_pd))
shape: (10, 2)
┌───────┬──────────────┐
│ value ┆ cut_label_pd │
│ ---   ┆ ---          │
│ i64   ┆ cat          │
╞═══════╪══════════════╡
│ 1     ┆ low          │
│ 2     ┆ low          │
│ 3     ┆ low          │
│ 4     ┆ medium       │
│ 5     ┆ medium       │
│ 6     ┆ high         │
│ 7     ┆ high         │
│ 8     ┆ very high    │
│ 9     ┆ very high    │
│ 10    ┆ very high    │
└───────┴──────────────┘

print(breakpoints)
# [ 0.991  3.25   5.5    7.75  10.   ]

Is there a better way? (Notice the extra values in labels in the Polars cut call.)

# Cut in Polars
labels = ["don't use it", "low", "medium", "high", "very high", "don't use it too"]
df_pl = df_pl.with_columns(
    pl.col("value").cut(breaks=breakpoints, labels=labels).alias("cut_label_pl")
)

print(df_pl)
shape: (10, 2)
┌───────┬──────────────┐
│ value ┆ cut_label_pl │
│ ---   ┆ ---          │
│ i64   ┆ cat          │
╞═══════╪══════════════╡
│ 1     ┆ low          │
│ 2     ┆ low          │
│ 3     ┆ low          │
│ 4     ┆ medium       │
│ 5     ┆ medium       │
│ 6     ┆ high         │
│ 7     ┆ high         │
│ 8     ┆ very high    │
│ 9     ┆ very high    │
│ 10    ┆ very high    │
└───────┴──────────────┘

Solution

  • The short answer is that Polars doesn't need as many breaks as pandas' retbins argument returns. The Polars docstring for labels states that "The number of labels must be equal to the number of cut points plus one." Since we have 4 labels, we need 3 breaks. pandas returns all 5 bin edges, including the outer ones (0.991 and 10.0), but Polars' cut already extends its first and last intervals to -inf and inf, so it does not need the first or the last edge produced by pandas.

    Instead of adding fake labels, just reduce the number of breaks. You could change your existing code from pl.col("value").cut(breaks=breakpoints, ...) to pl.col("value").cut(breaks=breakpoints[1:-1], ...), then remove the two "don't use it" labels, which is a bit nicer.

    But you probably don't want to depend on pandas just to calculate some evenly spaced bins, so let's do it ourselves!

    Starting with a baseline:

    import polars as pl

    data = {"value": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]}
    df = pl.DataFrame(data)
    
    # we know these in this case, but we want to generate them dynamically
    breaks = [3.25, 5.5, 7.75]
    labels = ["low", "medium", "high", "very high"] 
    df.with_columns(
        pl.col("value").cut(breaks=breaks, labels=labels).alias("cut_label_pl")
    )
    

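    Now let's calculate these breaks. The pandas.cut documentation says that an integer bins "defines the number of equal-width bins in the range of x." So each bin is (max - min) / bins wide, and the interior breaks sit at min + width * i for i = 1, ..., bins - 1. (pandas also extends the range by 0.1% so that the minimum value is included, which is where its first edge of 0.991 comes from.)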

    def calculate_breakpoints(ser: list | pl.Series, bins: int) -> list:
        if isinstance(ser, list):
            ser = pl.Series(ser)
        min_value, max_value = ser.min(), ser.max() # 1, 10
        bin_size = (max_value - min_value) / bins # (10 - 1) / 4 -> 2.25
        return [min_value + (bin_size * i) for i in range(1, bins)]
    
    # can take a list or a Polars Series
    calculate_breakpoints(data["value"], 4) # [3.25, 5.5, 7.75]
    calculate_breakpoints(df["value"], 4) # [3.25, 5.5, 7.75]
    

    Putting it all together (this still works if you change the number of labels):

    data = {"value": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]}
    df = pl.DataFrame(data)
    
    labels = ["low", "medium", "high", "very high"]
    breaks = calculate_breakpoints(df["value"], len(labels))
    df.with_columns(
        pl.col("value").cut(breaks=breaks, labels=labels).alias("cut_label_pl")
    )
    

    Good luck on your migration from pandas to Polars!
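To see why the question's version needed two filler labels, here is a small pure-Python sketch of the "number of labels = number of cut points plus one" rule. The helper assign_label is hypothetical: it only mimics Polars' documented default behavior (right-closed intervals, with the first and last intervals extending to -inf and inf) and is not how Polars implements cut.

```python
from bisect import bisect_left


def assign_label(value: float, breaks: list[float], labels: list[str]) -> str:
    """Hypothetical helper mimicking Polars' default (right-closed) cut.

    n sorted breaks split the number line into n + 1 intervals:
    (-inf, b1], (b1, b2], ..., (bn, inf), and each interval gets one label.
    """
    if len(labels) != len(breaks) + 1:
        raise ValueError("need exactly len(breaks) + 1 labels")
    # bisect_left counts the breaks strictly below `value`, which is the
    # index of the right-closed interval that contains it.
    return labels[bisect_left(breaks, value)]


values = list(range(1, 11))
four_labels = ["low", "medium", "high", "very high"]

# Interior breaks only: 3 breaks -> 4 intervals -> 4 labels.
with_interior = [assign_label(v, [3.25, 5.5, 7.75], four_labels) for v in values]

# All five pandas edges: 5 breaks -> 6 intervals, so two filler labels are
# needed for (-inf, 0.991] and (10.0, inf), which no value falls into.
six_labels = ["don't use it", *four_labels, "don't use it too"]
with_edges = [assign_label(v, [0.991, 3.25, 5.5, 7.75, 10.0], six_labels) for v in values]

print(with_interior == with_edges)  # True
print(with_interior)
```

Both calls produce the same labels, which is why the filler labels in the question never show up in the output: they only name intervals that no value lands in.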
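The 0.991 in the pandas output can also be reproduced by hand. The sketch below is my reading of how pandas.cut builds edges for an integer bins with the default right=True: bins + 1 equally spaced edges from the minimum to the maximum, with the first edge lowered by 0.1% of the range (the question's output, where the last edge stays at 10.0, is consistent with only the first edge being adjusted). pandas_style_edges is a hypothetical helper, not a pandas API.

```python
def pandas_style_edges(values: list[float], bins: int) -> list[float]:
    """Hypothetical sketch of pandas.cut's edges for an integer `bins`
    (right=True); an approximation, not the pandas implementation."""
    lo, hi = min(values), max(values)
    step = (hi - lo) / bins
    # bins + 1 equally spaced edges from the minimum to the maximum
    edges = [lo + step * i for i in range(bins + 1)]
    # Lower the first edge by 0.1% of the range so the minimum falls inside
    # the first right-closed bin; this is where 0.991 comes from.
    edges[0] -= (hi - lo) * 0.001
    return edges


values = list(range(1, 11))
edges = pandas_style_edges(values, 4)
print(edges)  # first edge is approximately 0.991, last is 10.0

# Dropping the outer edges leaves the interior breaks that Polars' cut expects.
interior = edges[1:-1]
print(interior)  # [3.25, 5.5, 7.75]
```

The interior edges use the same min + width * i formula as calculate_breakpoints, so slicing off the outer edges gives the same breaks without needing pandas at all.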