python, python-polars

Polars make all groups the same size


Question

I'm trying to make all groups in a given data frame have the same size. In Starting point below, I show an example of a data frame that I wish to transform, and in Goal I demonstrate what I'm trying to achieve: group by the column group, make every group have a size of 4, and fill the 'missing' values with null - I hope that's clear.

I have tried several approaches but have not been able to figure this one out.

Starting point

dfa = pl.DataFrame(data={'group': ['a', 'a', 'a', 'b', 'b', 'c'],
                         'value': ['a1', 'a2', 'a3', 'b1', 'b2', 'c1']})
┌───────┬───────┐
│ group ┆ value │
│ ---   ┆ ---   │
│ str   ┆ str   │
╞═══════╪═══════╡
│ a     ┆ a1    │
│ a     ┆ a2    │
│ a     ┆ a3    │
│ b     ┆ b1    │
│ b     ┆ b2    │
│ c     ┆ c1    │
└───────┴───────┘

Goal

>>> make_groups_uniform(dfa, group_by='group', group_size=4)
┌───────┬───────┐
│ group ┆ value │
│ ---   ┆ ---   │
│ str   ┆ str   │
╞═══════╪═══════╡
│ a     ┆ a1    │
│ a     ┆ a2    │
│ a     ┆ a3    │
│ a     ┆ null  │
│ b     ┆ b1    │
│ b     ┆ b2    │
│ b     ┆ null  │
│ b     ┆ null  │
│ c     ┆ c1    │
│ c     ┆ null  │
│ c     ┆ null  │
│ c     ┆ null  │
└───────┴───────┘

Package version

polars: 1.1.0

Solution

I have based this on @jqurious's answer below.

>>> import polars as pl

>>> dfa = pl.DataFrame(data={'group': ['a', 'a', 'a', 'b', 'b', 'c'],
...                          'value': ['a1', 'a2', 'a3', 'b1', 'b2', 'c1']})

┌───────┬───────┐
│ group ┆ value │
│ ---   ┆ ---   │
│ str   ┆ str   │
╞═══════╪═══════╡
│ a     ┆ a1    │
│ a     ┆ a2    │
│ a     ┆ a3    │
│ b     ┆ b1    │
│ b     ┆ b2    │
│ c     ┆ c1    │
└───────┴───────┘

>>> (dfa
...  .with_columns(group_size=pl.col('group')
...                             .count()
...                             .over('group')
...                             .max()
...                             .explode())
...  .group_by('group', maintain_order=True)
...  .agg(pl.all()
...         .append(pl.repeat(None, pl.col('group_size') - pl.len())))
...  .select(pl.exclude('group_size'))
...  .explode(pl.exclude('group')))

┌───────┬───────┐
│ group ┆ value │
│ ---   ┆ ---   │
│ str   ┆ str   │
╞═══════╪═══════╡
│ a     ┆ a1    │
│ a     ┆ a2    │
│ a     ┆ a3    │
│ b     ┆ b1    │
│ b     ┆ b2    │
│ b     ┆ null  │
│ c     ┆ c1    │
│ c     ┆ null  │
│ c     ┆ null  │
└───────┴───────┘

Solution

  • The advantage of the approach below is that we don't transform the original DataFrame (except perhaps sorting, if you want to rearrange the groups); we only create additional rows and append them back to the original DataFrame.


    I've adjusted my answer a bit, based on the assumption that you want the size of each group to be the max over all group sizes, but it works just as well for a fixed group_size.

    • group_by() allows per-group calculation.
    • len() to determine the size of each group.
    • repeat_by() creates lists sized by the difference between the max() group size and each group's previously calculated size.
    • filter() to drop the empty lists, for the case where a group needs no extra rows.
    • explode() the lists back into a column.
    • concat() the new rows back onto the existing DataFrame.
    • (optional) sort() if you need the groups to be together.
    # you can use fixed group_size instead of pl.col.len.max() as well
    
    pl.concat([
        dfa,
        (
            dfa.group_by("group").len()
            .select(pl.col.group.repeat_by(pl.col.len.max() - pl.col.len))
            .filter(pl.col.group.list.len() != 0)
            .explode("group")
        )
    ], how="diagonal").sort("group")
    
    shape: (9, 2)
    ┌───────┬───────┐
    │ group ┆ value │
    │ ---   ┆ ---   │
    │ str   ┆ str   │
    ╞═══════╪═══════╡
    │ a     ┆ a1    │
    │ a     ┆ a2    │
    │ a     ┆ a3    │
    │ b     ┆ b1    │
    │ b     ┆ b2    │
    │ b     ┆ null  │
    │ c     ┆ c1    │
    │ c     ┆ null  │
    │ c     ┆ null  │
    └───────┴───────┘
    

    If you need a fixed group size, then repeat() is probably more performant, but the idea is the same - only generate the additional rows and append them back to the original DataFrame.

    group_size = 3
    
    # you can make it dynamic as well though
    # group_size = dfa.group_by("group").len().max()["len"].item()
    
    pl.concat([
        dfa,
        (
            dfa.group_by("group")
            .agg(value = pl.repeat(None, group_size - pl.len().cast(int)))
            .filter(pl.col.value.list.len() != 0)
            .explode("value")
        )
    ]).sort("group")
    
    shape: (9, 2)
    ┌───────┬───────┐
    │ group ┆ value │
    │ ---   ┆ ---   │
    │ str   ┆ str   │
    ╞═══════╪═══════╡
    │ a     ┆ a1    │
    │ a     ┆ a2    │
    │ a     ┆ a3    │
    │ b     ┆ b1    │
    │ b     ┆ b2    │
    │ b     ┆ null  │
    │ c     ┆ c1    │
    │ c     ┆ null  │
    │ c     ┆ null  │
    └───────┴───────┘