
Properly groupby and filter with Polars


I have a DataFrame with 3 main columns (cid1, cid2, cid3) and 7 more columns (cid4, cid5, etc.).

cid1 and cid2 are integers; the other columns are floats.

Each combination of cid1 and cid2 is a working set of several rows, in which the values of all the other columns differ. I want to filter df so that, for each combination of cid1 and cid2, only the rows with the maximum value of cid3 remain. cid4 and the following columns must be left unchanged.

This code helps me with one part of my task:

df = (df
    .group_by("cid1", "cid2")
    .agg(pl.max("cid3").alias("max_cid3"))
)

It returns only 3 columns (cid1, cid2 and max_cid3), with one row per group holding the maximum cid3. But I can't find how to also get the other columns (cid4, etc.) from those rows, unchanged.

df = (df
    .group_by("cid1", "cid2")
    .agg(pl.max("cid3").alias("max_cid3"), pl.col("cid4"))
)

I tried adding pl.col("cid4") to the list of aggregations, but then each value in that column is a list of the group's cid4 values.

How can I do this properly? Maybe Polars has another way to do it than group_by?

In pandas I can do it like this:

import pandas as pd
import numpy as np

df["max_cid3"] = df.groupby(['cid1', 'cid2'])['cid3'].transform(np.max)

and then filter df to the rows where cid3 == max_cid3. But I can't find a way to do this in Polars.
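The pandas workflow described above, written out in full as a minimal sketch on illustrative sample data. It uses the string alias "max" instead of np.max; both compute the group maximum, and transform broadcasts it back onto every original row.

```python
import pandas as pd

# Illustrative sample data.
df = pd.DataFrame(
    {
        "cid1": [1, 2, 2, 1, 3],
        "cid2": [1, 2, 2, 1, 3],
        "cid3": [1, 2, 9, 1, 1],
        "cid4": [4, 5, 6, 7, 8],
        "cid5": [4, 5, 4, 9, 3],
    }
)

# transform returns a Series aligned with df's index: each row gets
# the maximum cid3 of its (cid1, cid2) group.
df["max_cid3"] = df.groupby(["cid1", "cid2"])["cid3"].transform("max")

# Keep only rows holding their group's maximum, then drop the helper column.
result = df[df["cid3"] == df["max_cid3"]].drop(columns="max_cid3")
print(result)
```

Rows that tie for the maximum (here the two (1, 1) rows) are all kept.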

Thank you!


Solution

  • In Polars you can use window functions.

    import polars as pl

    df = pl.DataFrame(
        {"cid1": [1, 2, 2, 1, 3],
         "cid2": [1, 2, 2, 1, 3],
         "cid3": [1, 2, 9, 1, 1],
         "cid4": [4, 5, 6, 7, 8],
         "cid5": [4, 5, 4, 9, 3]}
    )
    
    # Broadcast each (cid1, cid2) group's maximum of cid3 back onto every row.
    df.with_columns(
        max_cid3=pl.col("cid3").max().over("cid1", "cid2")
    )
    
    shape: (5, 6)
    ┌──────┬──────┬──────┬──────┬──────┬──────────┐
    │ cid1 ┆ cid2 ┆ cid3 ┆ cid4 ┆ cid5 ┆ max_cid3 │
    │ ---  ┆ ---  ┆ ---  ┆ ---  ┆ ---  ┆ ---      │
    │ i64  ┆ i64  ┆ i64  ┆ i64  ┆ i64  ┆ i64      │
    ╞══════╪══════╪══════╪══════╪══════╪══════════╡
    │ 1    ┆ 1    ┆ 1    ┆ 4    ┆ 4    ┆ 1        │
    │ 2    ┆ 2    ┆ 2    ┆ 5    ┆ 5    ┆ 9        │
    │ 2    ┆ 2    ┆ 9    ┆ 6    ┆ 4    ┆ 9        │
    │ 1    ┆ 1    ┆ 1    ┆ 7    ┆ 9    ┆ 1        │
    │ 3    ┆ 3    ┆ 1    ┆ 8    ┆ 3    ┆ 1        │
    └──────┴──────┴──────┴──────┴──────┴──────────┘
    

    You can also use the expression directly inside .filter(), comparing each row's cid3 with its group maximum:

    df.filter(pl.col("cid3") == pl.col("cid3").max().over("cid1", "cid2"))
    
    shape: (4, 5)
    ┌──────┬──────┬──────┬──────┬──────┐
    │ cid1 ┆ cid2 ┆ cid3 ┆ cid4 ┆ cid5 │
    │ ---  ┆ ---  ┆ ---  ┆ ---  ┆ ---  │
    │ i64  ┆ i64  ┆ i64  ┆ i64  ┆ i64  │
    ╞══════╪══════╪══════╪══════╪══════╡
    │ 1    ┆ 1    ┆ 1    ┆ 4    ┆ 4    │
    │ 2    ┆ 2    ┆ 9    ┆ 6    ┆ 4    │
    │ 1    ┆ 1    ┆ 1    ┆ 7    ┆ 9    │
    │ 3    ┆ 3    ┆ 1    ┆ 8    ┆ 3    │
    └──────┴──────┴──────┴──────┴──────┘
    

    Pandas for comparison

    >>> df.to_pandas().groupby(["cid1", "cid2"])["cid3"].transform("max")
    0    1
    1    9
    2    9
    3    1
    4    1
    Name: cid3, dtype: int64