If I have data like this (which is a mock example)
data = {'col1': [1, 1, 1, 1, 1, 2, 2], 'col2': ['br', 'bra', 'bra', 'col', 'col', 'br', 'b'], 'col3': ['brazil', 'brazil', 'brasil', 'collombia', 'columbia', 'brazil', 'brazil']}
df = pl.DataFrame(data)
How can I do a group-by on it such that a row belongs to a group if 1. its col1 value is equal, and 2. its col2 value OR its col3 value is equal to that of at least one other row in the group?
So, for example:

- col1 = 1, and the values col2 = 'bra', col3 = 'brazil' are duplicated
- col1 = 1, and the value col2 = 'col' is duplicated
- col1 = 2, and col3 = 'brazil' is duplicated

So such an operation would yield:
{'col1': [1, 1, 1, 1, 1, 2, 2], 'col2': ['br', 'bra', 'bra', 'col', 'col', 'br', 'b'], 'col3': ['brazil', 'brazil', 'brasil', 'collombia', 'columbia', 'brazil', 'brazil'], 'group_idx': [1, 1, 1, 2, 2, 3, 3]}
Your question is not entirely clear, so I'm making some assumptions.
First, I'm assuming that your condition is: grouping has to happen within col1, and rows belong to one group if they match on at least one of col2, col3.
Second, it's not clear whether you want to group only consecutive rows, or group them regardless of their position.
So here are your 2 options:
Grouping rows regardless of their position.
To illustrate, DataFrame like this:
┌──────┬──────┬────────┐
│ col1 ┆ col2 ┆ col3   │
│ ---  ┆ ---  ┆ ---    │
│ i64  ┆ str  ┆ str    │
╞══════╪══════╪════════╡
│ 1    ┆ bra  ┆ brazil │
│ 1    ┆ br   ┆ brasil │
│ 1    ┆ bra  ┆ brazil │
│ 1    ┆ br   ┆ brazil │
└──────┴──────┴────────┘
Will have group_idx as follows:
┌──────┬──────┬────────┬───────────┐
│ col1 ┆ col2 ┆ col3   ┆ group_idx │
│ ---  ┆ ---  ┆ ---    ┆ ---       │
│ i64  ┆ str  ┆ str    ┆ i64       │
╞══════╪══════╪════════╪═══════════╡
│ 1    ┆ bra  ┆ brazil ┆ 1         │
│ 1    ┆ br   ┆ brasil ┆ 1         │
│ 1    ┆ bra  ┆ brazil ┆ 1         │
│ 1    ┆ br   ┆ brazil ┆ 1         │
└──────┴──────┴────────┴───────────┘
Essentially, you can imagine your table rows as nodes of a graph, with an edge linking two nodes if they have the same col1 and at least one of col2, col3 in common. What we want is to find all connected subgraphs (see also the Flood fill algorithm).
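To make the connected-subgraphs idea concrete, here is an illustrative plain-Python sketch (separate from the Polars solution below) that unions rows sharing (col1, col2) or (col1, col3) with a small union-find (disjoint-set) structure; all helper names here are my own:

```python
# Mock data copied from the question.
data = {
    'col1': [1, 1, 1, 1, 1, 2, 2],
    'col2': ['br', 'bra', 'bra', 'col', 'col', 'br', 'b'],
    'col3': ['brazil', 'brazil', 'brasil', 'collombia', 'columbia', 'brazil', 'brazil'],
}

n = len(data['col1'])
parent = list(range(n))  # disjoint-set forest over row indices

def find(i):
    while parent[i] != i:
        parent[i] = parent[parent[i]]  # path halving
        i = parent[i]
    return i

def union(i, j):
    parent[find(i)] = find(j)

# Union every row with the first row seen that shares the same key,
# where a key is (col1, which-column, value).
seen = {}
for i in range(n):
    for key in ((data['col1'][i], 'col2', data['col2'][i]),
                (data['col1'][i], 'col3', data['col3'][i])):
        if key in seen:
            union(i, seen[key])
        else:
            seen[key] = i

# Assign dense group ids in order of first appearance.
roots = {}
group_idx = [roots.setdefault(find(i), len(roots) + 1) for i in range(n)]
print(group_idx)  # [1, 1, 1, 2, 2, 3, 3]
```

This is a single linear pass, whereas the self-join approach below repeats until convergence; both compute the same connected components.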
To achieve that, you can self-join the DataFrame on col1 + col2 and on col1 + col3, assigning each group the minimal group_idx, and repeat until all rows which have something in common via other rows have been merged:
import polars as pl

dfi = df.with_row_index('group_idx')
while True:
    # propagate the minimal group_idx across rows sharing col1+col2, then col1+col3
    for col in ['col2', 'col3']:
        dfi = (
            dfi
            .group_by('col1', col)
            .agg(pl.col('group_idx').min())
            .join(dfi, on=['col1', col], suffix=col)
        )
    # converged once no row's group_idx decreased during this iteration
    if dfi.filter(
        pl.min_horizontal('group_idxcol2', 'group_idxcol3') > pl.col('group_idx')
    ).is_empty():
        break
    dfi = dfi.drop('group_idxcol2', 'group_idxcol3')

dfi.select('col1', 'col2', 'col3', pl.col('group_idx').rank('dense'))
┌──────┬──────┬───────────┬───────────┐
│ col1 ┆ col2 ┆ col3      ┆ group_idx │
│ ---  ┆ ---  ┆ ---       ┆ ---       │
│ i64  ┆ str  ┆ str       ┆ u32       │
╞══════╪══════╪═══════════╪═══════════╡
│ 1    ┆ br   ┆ brazil    ┆ 1         │
│ 1    ┆ bra  ┆ brazil    ┆ 1         │
│ 1    ┆ bra  ┆ brasil    ┆ 1         │
│ 1    ┆ col  ┆ collombia ┆ 2         │
│ 1    ┆ col  ┆ columbia  ┆ 2         │
│ 2    ┆ br   ┆ brazil    ┆ 3         │
│ 2    ┆ b    ┆ brazil    ┆ 3         │
└──────┴──────┴───────────┴───────────┘
Grouping consecutive rows only.
To illustrate again, DataFrame like this:
┌──────┬──────┬────────┐
│ col1 ┆ col2 ┆ col3   │
│ ---  ┆ ---  ┆ ---    │
│ i64  ┆ str  ┆ str    │
╞══════╪══════╪════════╡
│ 1    ┆ bra  ┆ brazil │
│ 1    ┆ br   ┆ brasil │
│ 1    ┆ bra  ┆ brazil │
│ 1    ┆ br   ┆ brazil │
└──────┴──────┴────────┘
Will have group_idx as follows:
┌──────┬──────┬────────┬───────────┐
│ col1 ┆ col2 ┆ col3   ┆ group_idx │
│ ---  ┆ ---  ┆ ---    ┆ ---       │
│ i64  ┆ str  ┆ str    ┆ i64       │
╞══════╪══════╪════════╪═══════════╡
│ 1    ┆ bra  ┆ brazil ┆ 1         │
│ 1    ┆ br   ┆ brasil ┆ 2         │ <-- new group because col2, col3 do not match
│ 1    ┆ bra  ┆ brazil ┆ 3         │ <-- new group because col2, col3 do not match
│ 1    ┆ br   ┆ brazil ┆ 3         │
└──────┴──────┴────────┴───────────┘
If these assumptions are correct, then you can compare row values with the previous row's values using Expr.shift() to check whether the condition still holds (col1 is still the same, and at least one of col2 or col3 is still the same).
So, first, consider the following DataFrame:
(
    df
    .with_columns(
        col1_change=pl.col('col1') != pl.col('col1').shift(1),
        col2_change=pl.col('col2') != pl.col('col2').shift(1),
        col3_change=pl.col('col3') != pl.col('col3').shift(1),
    ).fill_null(True)
)
┌──────┬──────┬───────────┬─────────────┬─────────────┬─────────────┐
│ col1 ┆ col2 ┆ col3      ┆ col1_change ┆ col2_change ┆ col3_change │
│ ---  ┆ ---  ┆ ---       ┆ ---         ┆ ---         ┆ ---         │
│ i64  ┆ str  ┆ str       ┆ bool        ┆ bool        ┆ bool        │
╞══════╪══════╪═══════════╪═════════════╪═════════════╪═════════════╡
│ 1    ┆ br   ┆ brazil    ┆ true        ┆ true        ┆ true        │
│ 1    ┆ bra  ┆ brazil    ┆ false       ┆ true        ┆ false       │
│ 1    ┆ bra  ┆ brasil    ┆ false       ┆ false       ┆ true        │
│ 1    ┆ col  ┆ collombia ┆ false       ┆ true        ┆ true        │
│ 1    ┆ col  ┆ columbia  ┆ false       ┆ false       ┆ true        │
│ 2    ┆ br   ┆ brazil    ┆ true        ┆ true        ┆ true        │
│ 2    ┆ b    ┆ brazil    ┆ false       ┆ true        ┆ false       │
└──────┴──────┴───────────┴─────────────┴─────────────┴─────────────┘
Now, we don't care if col2 changes but col3 doesn't, and vice versa. To express this combined condition we can use .min_horizontal():
(
    df
    .with_columns(
        col1_change=pl.col('col1') != pl.col('col1').shift(1),
        col2_or_col3_change=pl.min_horizontal(pl.col('col2', 'col3') != pl.col('col2', 'col3').shift(1))
    ).fill_null(True)
)
┌──────┬──────┬───────────┬─────────────┬─────────────────────┐
│ col1 ┆ col2 ┆ col3      ┆ col1_change ┆ col2_or_col3_change │
│ ---  ┆ ---  ┆ ---       ┆ ---         ┆ ---                 │
│ i64  ┆ str  ┆ str       ┆ bool        ┆ bool                │
╞══════╪══════╪═══════════╪═════════════╪═════════════════════╡
│ 1    ┆ br   ┆ brazil    ┆ true        ┆ true                │
│ 1    ┆ bra  ┆ brazil    ┆ false       ┆ false               │
│ 1    ┆ bra  ┆ brasil    ┆ false       ┆ false               │
│ 1    ┆ col  ┆ collombia ┆ false       ┆ true                │
│ 1    ┆ col  ┆ columbia  ┆ false       ┆ false               │
│ 2    ┆ br   ┆ brazil    ┆ true        ┆ true                │
│ 2    ┆ b    ┆ brazil    ┆ false       ┆ false               │
└──────┴──────┴───────────┴─────────────┴─────────────────────┘
Now, we need to start a new group when either of these conditions is true, and for that we can use either | or .max_horizontal(). I'm using .max_horizontal() because it allows me to switch from DataFrame.fill_null() to Expr.fill_null(), which is going to be handy in the next step:
(
    df
    .with_columns(
        group_start=pl.max_horizontal(
            (pl.col('col1') != pl.col('col1').shift(1)) |
            pl.min_horizontal(pl.col('col2', 'col3') != pl.col('col2', 'col3').shift(1))
        ).fill_null(True)
    )
)
┌──────┬──────┬───────────┬─────────────┐
│ col1 ┆ col2 ┆ col3      ┆ group_start │
│ ---  ┆ ---  ┆ ---       ┆ ---         │
│ i64  ┆ str  ┆ str       ┆ bool        │
╞══════╪══════╪═══════════╪═════════════╡
│ 1    ┆ br   ┆ brazil    ┆ true        │
│ 1    ┆ bra  ┆ brazil    ┆ false       │
│ 1    ┆ bra  ┆ brasil    ┆ false       │
│ 1    ┆ col  ┆ collombia ┆ true        │
│ 1    ┆ col  ┆ columbia  ┆ false       │
│ 2    ┆ br   ┆ brazil    ┆ true        │
│ 2    ┆ b    ┆ brazil    ┆ false       │
└──────┴──────┴───────────┴─────────────┘
And now, you can use Expr.cum_sum() to enumerate the groups:
(
    df
    .with_columns(
        group_idx=pl.max_horizontal(
            (pl.col('col1') != pl.col('col1').shift(1)) |
            pl.min_horizontal(pl.col('col2', 'col3') != pl.col('col2', 'col3').shift(1))
        ).fill_null(True).cum_sum()
    )
)
┌──────┬──────┬───────────┬───────────┐
│ col1 ┆ col2 ┆ col3      ┆ group_idx │
│ ---  ┆ ---  ┆ ---       ┆ ---       │
│ i64  ┆ str  ┆ str       ┆ u32       │
╞══════╪══════╪═══════════╪═══════════╡
│ 1    ┆ br   ┆ brazil    ┆ 1         │
│ 1    ┆ bra  ┆ brazil    ┆ 1         │
│ 1    ┆ bra  ┆ brasil    ┆ 1         │
│ 1    ┆ col  ┆ collombia ┆ 2         │
│ 1    ┆ col  ┆ columbia  ┆ 2         │
│ 2    ┆ br   ┆ brazil    ┆ 3         │
│ 2    ┆ b    ┆ brazil    ┆ 3         │
└──────┴──────┴───────────┴───────────┘
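As a sanity check, the shift-and-cum_sum logic above boils down to a single linear scan over the rows. Here's an illustrative plain-Python sketch of the same rule applied to the mock data (my own formulation, not part of the Polars solution):

```python
# Mock data from the question, as (col1, col2, col3) row tuples.
rows = list(zip(
    [1, 1, 1, 1, 1, 2, 2],
    ['br', 'bra', 'bra', 'col', 'col', 'br', 'b'],
    ['brazil', 'brazil', 'brasil', 'collombia', 'columbia', 'brazil', 'brazil'],
))

group_idx = []
prev = None
for c1, c2, c3 in rows:
    # A group starts when col1 changed, or when neither col2 nor col3
    # matches the previous row -- the max_horizontal / min_horizontal
    # condition above.  The first row (prev is None) always starts one.
    start = prev is None or c1 != prev[0] or (c2 != prev[1] and c3 != prev[2])
    # Running total of group starts == cum_sum of the boolean column.
    group_idx.append((group_idx[-1] if group_idx else 0) + start)
    prev = (c1, c2, c3)

print(group_idx)  # [1, 1, 1, 2, 2, 3, 3]
```

On this input both options happen to agree; they differ only when matching rows are not adjacent.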