Tags: python, python-polars

Using `hist` to bin data while grouping with `over`?


Consider the following example:

import polars as pl

df = pl.DataFrame(
    [
        pl.Series(
            "name", ["A", "B", "C", "D"], dtype=pl.Enum(["A", "B", "C", "D"])
        ),
        pl.Series("month", [1, 2, 12, 1], dtype=pl.Int8()),
        pl.Series(
            "category", ["x", "x", "y", "z"], dtype=pl.Enum(["x", "y", "z"])
        ),
    ]
)
print(df)
shape: (4, 3)
┌──────┬───────┬──────────┐
│ name ┆ month ┆ category │
│ ---  ┆ ---   ┆ ---      │
│ enum ┆ i8    ┆ enum     │
╞══════╪═══════╪══════════╡
│ A    ┆ 1     ┆ x        │
│ B    ┆ 2     ┆ x        │
│ C    ┆ 12    ┆ y        │
│ D    ┆ 1     ┆ z        │
└──────┴───────┴──────────┘

We can count how many of the dataframe's month values fall into each month of the year:

from math import inf


binned_df = (
    df.select(
        pl.col.month.hist(
            # breakpoints 1, 2, ..., 11; anything above 11 falls into a
            # final bin whose breakpoint is inf
            bins=[x + 1 for x in range(11)],
            include_breakpoint=True,
        ).alias("binned"),
    )
    .unnest("binned")
    .with_columns(
        # remap the final bin's inf breakpoint to month 12
        pl.col.breakpoint.map_elements(
            lambda x: 12 if x == inf else x, return_dtype=pl.Float64()
        )
        .cast(pl.Int8())
        .alias("month")
    )
    .drop("breakpoint")
    .select("month", "count")
)
print(binned_df)
shape: (12, 2)
┌───────┬───────┐
│ month ┆ count │
│ ---   ┆ ---   │
│ i8    ┆ u32   │
╞═══════╪═══════╡
│ 1     ┆ 2     │
│ 2     ┆ 1     │
│ 3     ┆ 0     │
│ 4     ┆ 0     │
│ 5     ┆ 0     │
│ …     ┆ …     │
│ 8     ┆ 0     │
│ 9     ┆ 0     │
│ 10    ┆ 0     │
│ 11    ┆ 0     │
│ 12    ┆ 1     │
└───────┴───────┘
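
As an aside, the same remapping can be done without a Python-level lambda: Expr.replace rewrites the final bin's inf breakpoint to 12 entirely within the expression engine. A sketch of the equivalent binning (same output, assuming a polars version with Expr.replace):

binned_df = (
    df.select(
        pl.col.month.hist(
            bins=[x + 1 for x in range(11)],
            include_breakpoint=True,
        ).alias("binned"),
    )
    .unnest("binned")
    .with_columns(
        # replace the inf breakpoint of the last bin with 12, then cast
        pl.col.breakpoint.replace(float("inf"), 12)
        .cast(pl.Int8())
        .alias("month")
    )
    .drop("breakpoint")
    .select("month", "count")
)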

(Note: there are 3 categories "x", "y", and "z", so when we bin per category below, we expect a dataframe of 12 × 3 = 36 rows.)

Suppose I want to bin the data per the column "category". I can do the following:

# initialize an empty dataframe
category_binned_df = pl.DataFrame()

for cat in df["category"].unique():
    # repeat the binning logic from earlier, except on a dataframe filtered for
    # the particular category we are iterating over
    binned_df = (
        df.filter(pl.col.category.eq(cat))  # <--- the filter
        .select(
            pl.col.month.hist(
                bins=[x + 1 for x in range(11)],
                include_breakpoint=True,
            ).alias("binned"),
        )
        .unnest("binned")
        .with_columns(
            pl.col.breakpoint.map_elements(
                lambda x: 12 if x == inf else x, return_dtype=pl.Float64()
            )
            .cast(pl.Int8())
            .alias("month")
        )
        .drop("breakpoint")
        .select("month", "count")
        .with_columns(category=pl.lit(cat).cast(df["category"].dtype))
    )
    # finally, vstack ("append") the resulting dataframe 
    category_binned_df = category_binned_df.vstack(binned_df)
print(category_binned_df)
shape: (36, 3)
┌───────┬───────┬──────────┐
│ month ┆ count ┆ category │
│ ---   ┆ ---   ┆ ---      │
│ i8    ┆ u32   ┆ enum     │
╞═══════╪═══════╪══════════╡
│ 1     ┆ 1     ┆ x        │
│ 2     ┆ 1     ┆ x        │
│ 3     ┆ 0     ┆ x        │
│ 4     ┆ 0     ┆ x        │
│ 5     ┆ 0     ┆ x        │
│ …     ┆ …     ┆ …        │
│ 8     ┆ 0     ┆ z        │
│ 9     ┆ 0     ┆ z        │
│ 10    ┆ 0     ┆ z        │
│ 11    ┆ 0     ┆ z        │
│ 12    ┆ 1     ┆ z        │
└───────┴───────┴──────────┘
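
The same loop can also be sketched with DataFrame.partition_by and pl.concat in place of the manual filter and vstack (a sketch, assuming a recent polars where as_dict=True keys are tuples):

parts = []
for (cat,), part in df.partition_by("category", as_dict=True).items():
    parts.append(
        part.select(
            pl.col.month.hist(
                bins=[x + 1 for x in range(11)],
                include_breakpoint=True,
            ).alias("binned"),
        )
        .unnest("binned")
        .with_columns(
            # same inf -> 12 remap as above, plus the category label
            pl.col.breakpoint.replace(float("inf"), 12)
            .cast(pl.Int8())
            .alias("month"),
            category=pl.lit(cat, dtype=df["category"].dtype),
        )
        .select("month", "count", "category")
    )
category_binned_df = pl.concat(parts)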

It seems to me that there should be a way to do this using `over`, something like `pl.col.month.hist(bins=...).over("category")`, but the very first step of trying to do so raises an error:

df.select(
    pl.col.month.hist(
        bins=[x + 1 for x in range(11)],
        include_breakpoint=True,
    )
    .over("category")
    .alias("binned"),
)
ComputeError: the length of the window expression did not match that of the group

Error originated in expression: 'col("month").hist([Series]).over([col("category")])'

So is there some conceptual error in how I am thinking about `over`? Is there a way to use `over` here at all?


Solution

  • Here's one approach using Expr.over:

    bins = range(1, 12)

    out = (
        df.select(
            pl.col('month')
            .hist(bins=bins, include_breakpoint=True)
            .over(partition_by='category', mapping_strategy='explode')
            .alias('binned'),
            pl.col('category')
            .unique(maintain_order=True)
            .repeat_by(len(bins) + 1)
            .flatten(),
        )
        .unnest('binned')
        .with_columns(pl.col('breakpoint').replace(float('inf'), 12).cast(int))
        .rename({'breakpoint': 'month'})
    )
    

    Output:

    shape: (36, 3)
    ┌───────┬───────┬──────────┐
    │ month ┆ count ┆ category │
    │ ---   ┆ ---   ┆ ---      │
    │ i64   ┆ u32   ┆ enum     │
    ╞═══════╪═══════╪══════════╡
    │ 1     ┆ 1     ┆ x        │
    │ 2     ┆ 1     ┆ x        │
    │ 3     ┆ 0     ┆ x        │
    │ 4     ┆ 0     ┆ x        │
    │ 5     ┆ 0     ┆ x        │
    │ …     ┆ …     ┆ …        │
    │ 8     ┆ 0     ┆ z        │
    │ 9     ┆ 0     ┆ z        │
    │ 10    ┆ 0     ┆ z        │
    │ 11    ┆ 0     ┆ z        │
    │ 12    ┆ 1     ┆ z        │
    └───────┴───────┴──────────┘
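
    The category column is rebuilt by repeating each unique category once per bin: hist emits len(bins) + 1 = 12 rows per group, so repeat_by(len(bins) + 1) followed by flatten lines the categories up with the exploded histogram rows. A tiny illustration of the pattern, using 3 repeats instead of 12:

    df.select(
        pl.col('category').unique(maintain_order=True).repeat_by(3).flatten()
    )
    # category column: x, x, x, y, y, y, z, z, z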
    

    Explanation

    • The key is to use mapping_strategy='explode'. The default mapping_strategy='group_to_rows' expects the window expression to return one value per row of the group, but hist returns len(bins) + 1 rows per group (12 here, versus e.g. 2 rows in group "x"), which is exactly the length mismatch the error complains about. As mentioned in the docs, under explode:

    Explodes the grouped data into new rows, similar to the results of group_by + agg + explode. Sorting of the given groups is required if the groups are not part of the window operation for the operation, otherwise the result would not make sense. This operation changes the number of rows.

    (I do not think sorting is required here, but please correct me if I am wrong.)
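
    For intuition, here is the same result written as the explicit group_by + agg + explode pipeline that the docs liken explode to (a sketch reusing bins from above):

    out_gb = (
        df.group_by('category', maintain_order=True)
        .agg(
            pl.col('month')
            .hist(bins=bins, include_breakpoint=True)
            .alias('binned')
        )
        .explode('binned')
        .unnest('binned')
        .with_columns(pl.col('breakpoint').replace(float('inf'), 12).cast(int))
        .rename({'breakpoint': 'month'})
    )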

    Adding a performance comparison with the method suggested in the answer by @BallpointBen, testing:

    • over: over, maintaining order + 'category'
    • over_ex_cat: over, maintaining order, ex 'category' (highlights the bottleneck)
    • group_by: group_by, not maintaining order + 'category'
    • group_by_order: group_by, maintaining order + 'category'

    I've left out trivial operations like renaming the "breakpoint" column and getting the columns into the same order. The script can be found here (updated for the second plot below).

    [perfplot: runtime vs. N for the four variants above]

    Maybe someone can suggest a better way to get back the categories. Otherwise, there does not seem to be much between the two methods.


    Update: a performance comparison with the answers suggested by @HenryHarback, testing:

    • over: over, maintaining order + 'category' (= mapping_strategy='explode')
    • over_join: over, not maintaining order + 'category' (= mapping_strategy='join')
    • spine: cross join + left join, not maintaining order + 'category' (sketched after this list)
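
    For reference, a sketch of what the spine idea can look like (my reconstruction, not necessarily @HenryHarback's exact code): build the full month × category grid with a cross join, then left-join the observed counts and fill the gaps with zeros:

    spine = (
        pl.DataFrame({'month': pl.Series(range(1, 13), dtype=pl.Int8())})
        .join(df.select(pl.col('category').unique()), how='cross')
    )
    counts = df.group_by('month', 'category').len(name='count')
    out_spine = (
        spine.join(counts, on=['month', 'category'], how='left')
        .with_columns(pl.col('count').fill_null(0))
    )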

    Not included is the group_by + select + struct option, which performs similarly to the group_by variant compared above (with unnest). I extended the N-range to show spine catching up with, though apparently not overtaking, over as the dataframe gets really big.

    [perfplot2: the same comparison over an extended N-range]