Group DataFrame Rows Based on Multiple Column Matching Conditions

If I have data like this (which is a mock example)

import polars as pl
data = {'col1': [1, 1, 1, 1, 1, 2, 2], 'col2': ['br', 'bra', 'bra', 'col', 'col', 'br', 'b'], 'col3': ['brazil', 'brazil', 'brasil', 'collombia', 'columbia', 'brazil', 'brazil']}
df = pl.DataFrame(data)

How can I do a group-by on it such that a row belongs to a group if (1) its col1 value equals that of the other rows in the group, and (2) its col2 value OR its col3 value equals that of at least one other row in the group?

So for example:

  1. The first 3 rows are a group because col1 = 1, and the values col2 = 'bra', col3 = 'brazil' are duplicated
  2. Rows 4-5 (1-based indexing) are a group because col1 = 1, and the value col2 = 'col' is duplicated
  3. Rows 6-7 are a group because col1 = 2 and col3 = 'brazil' is duplicated

So such an operation would yield

{'col1': [1, 1, 1, 1, 1, 2, 2], 'col2': ['br', 'bra', 'bra', 'col', 'col', 'br', 'b'], 'col3': ['brazil', 'brazil', 'brasil', 'collombia', 'columbia', 'brazil', 'brazil'], 'group_idx': [1, 1, 1, 2, 2, 3, 3],}


  • Your question is not entirely clear, so I'm making some assumptions.

    First, I'm assuming that your condition is: grouping has to happen within col1, and rows belong to one group if they have one of col2, col3 matching.

    Second, it's not clear if you want to group only consecutive rows or group them regardless of their position.

    So here are your 2 options:

    Grouping rows regardless of their position.

    To illustrate, a DataFrame like this:

    │ col1 ┆ col2 ┆ col3   │
    │ ---  ┆ ---  ┆ ---    │
    │ i64  ┆ str  ┆ str    │
    │ 1    ┆ bra  ┆ brazil │
    │ 1    ┆ br   ┆ brasil │
    │ 1    ┆ bra  ┆ brazil │
    │ 1    ┆ br   ┆ brazil │

    Will have group_idx as follows:

    │ col1 ┆ col2 ┆ col3   ┆ group_idx │
    │ ---  ┆ ---  ┆ ---    ┆ ---       │
    │ i64  ┆ str  ┆ str    ┆ i64       │
    │ 1    ┆ bra  ┆ brazil ┆ 1         │
    │ 1    ┆ br   ┆ brasil ┆ 1         │
    │ 1    ┆ bra  ┆ brazil ┆ 1         │
    │ 1    ┆ br   ┆ brazil ┆ 1         │

    Essentially, you can imagine your table rows as nodes of the graph, and nodes have an edge linking them if they have the same col1 and one of col2, col3 in common. What we want is to find all connected subgraphs (see also Flood fill algorithm).

    To achieve that, you can self-join the DataFrame on col1 + col2 and on col1 + col3, assign group_idx, and repeat until all rows that are connected through other rows have been merged:

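    dfi = df.with_row_index('group_idx')
    while True:
        # Propagate the smallest group_idx within each (col1, col2) group, then each (col1, col3) group
        for col in ['col2','col3']:
            dfi = (
                dfi
                .group_by('col1', col)
                .agg(pl.col('group_idx').min())
                .join(dfi, on=['col1', col], suffix=col)
            )
        # Stop once a full pass leaves every group_idx unchanged
        if dfi.filter(
               pl.min_horizontal('group_idxcol2', 'group_idxcol3') > pl.col('group_idx')
           ).is_empty():
            break
        dfi = dfi.drop('group_idxcol2', 'group_idxcol3')

    dfi.select('col1','col2','col3', pl.col('group_idx').rank('dense'))
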
    │ col1 ┆ col2 ┆ col3      ┆ group_idx │
    │ ---  ┆ ---  ┆ ---       ┆ ---       │
    │ i64  ┆ str  ┆ str       ┆ u32       │
    │ 1    ┆ br   ┆ brazil    ┆ 1         │
    │ 1    ┆ bra  ┆ brazil    ┆ 1         │
    │ 1    ┆ bra  ┆ brasil    ┆ 1         │
    │ 1    ┆ col  ┆ collombia ┆ 2         │
    │ 1    ┆ col  ┆ columbia  ┆ 2         │
    │ 2    ┆ br   ┆ brazil    ┆ 3         │
    │ 2    ┆ b    ┆ brazil    ┆ 3         │
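
As a cross-check of the graph view described above, here is a minimal pure-Python sketch that finds the same connected groups with a union-find structure instead of repeated joins. It is not part of the Polars solution; the function name `group_rows` and its list-based inputs are assumptions made for illustration only.

```python
def group_rows(col1, col2, col3):
    """Return 1-based group ids. Rows are linked when they share col1 and
    at least one of col2 / col3; links are transitive."""
    n = len(col1)
    parent = list(range(n))

    def find(i):
        # Walk up to the root, shortening the path along the way.
        while parent[i] != i:
            parent[i] = parent[parent[i]]
            i = parent[i]
        return i

    def union(i, j):
        ri, rj = find(i), find(j)
        if ri != rj:
            # Keep the smaller row index as the root of the merged set.
            parent[max(ri, rj)] = min(ri, rj)

    # Link each row to the first row seen with the same (col1, col2)
    # key or the same (col1, col3) key.
    first_seen = {}
    for i in range(n):
        for key in (("col2", col1[i], col2[i]), ("col3", col1[i], col3[i])):
            union(i, first_seen.setdefault(key, i))

    # Number the groups densely, in order of their first row.
    ids = {}
    return [ids.setdefault(find(i), len(ids) + 1) for i in range(n)]


col1 = [1, 1, 1, 1, 1, 2, 2]
col2 = ['br', 'bra', 'bra', 'col', 'col', 'br', 'b']
col3 = ['brazil', 'brazil', 'brasil', 'collombia', 'columbia', 'brazil', 'brazil']
print(group_rows(col1, col2, col3))  # [1, 1, 1, 2, 2, 3, 3]
```

On the 4-row illustration above, the same function puts all four rows into group 1, because the second row is connected to the others through the fourth row.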

    Grouping consecutive rows only.

    To illustrate again, a DataFrame like this:

    │ col1 ┆ col2 ┆ col3   │
    │ ---  ┆ ---  ┆ ---    │
    │ i64  ┆ str  ┆ str    │
    │ 1    ┆ bra  ┆ brazil │
    │ 1    ┆ br   ┆ brasil │
    │ 1    ┆ bra  ┆ brazil │
    │ 1    ┆ br   ┆ brazil │

    Will have group_idx as follows:

    │ col1 ┆ col2 ┆ col3   ┆ group_idx │
    │ ---  ┆ ---  ┆ ---    ┆ ---       │
    │ i64  ┆ str  ┆ str    ┆ i64       │
    │ 1    ┆ bra  ┆ brazil ┆ 1         │
    │ 1    ┆ br   ┆ brasil ┆ 2         │ <-- new group cause col2, col3 do not match
    │ 1    ┆ bra  ┆ brazil ┆ 3         │ <-- new group cause col2, col3 do not match
    │ 1    ┆ br   ┆ brazil ┆ 3         │

    If these assumptions are correct, then you can compare each row's values with the previous row's values using Expr.shift() to check whether the condition still holds (col1 is still the same, and at least one of col2 or col3 is still the same).

    So, first you can imagine the following DataFrame:

    df.with_columns(
            col1_change=pl.col('col1') != pl.col('col1').shift(1),
            col2_change=pl.col('col2') != pl.col('col2').shift(1),
            col3_change=pl.col('col3') != pl.col('col3').shift(1),
    ).fill_null(True)  # the first row has no previous row, so count it as a change

    │ col1 ┆ col2 ┆ col3      ┆ col1_change ┆ col2_change ┆ col3_change │
    │ ---  ┆ ---  ┆ ---       ┆ ---         ┆ ---         ┆ ---         │
    │ i64  ┆ str  ┆ str       ┆ bool        ┆ bool        ┆ bool        │
    │ 1    ┆ br   ┆ brazil    ┆ true        ┆ true        ┆ true        │
    │ 1    ┆ bra  ┆ brazil    ┆ false       ┆ true        ┆ false       │
    │ 1    ┆ bra  ┆ brasil    ┆ false       ┆ false       ┆ true        │
    │ 1    ┆ col  ┆ collombia ┆ false       ┆ true        ┆ true        │
    │ 1    ┆ col  ┆ columbia  ┆ false       ┆ false       ┆ true        │
    │ 2    ┆ br   ┆ brazil    ┆ true        ┆ true        ┆ true        │
    │ 2    ┆ b    ┆ brazil    ┆ false       ┆ true        ┆ false       │

    Now, we don't care if col2 changes but col3 doesn't and vice versa. To adjust this condition we can use .min_horizontal():

    df.with_columns(
            col1_change=pl.col('col1') != pl.col('col1').shift(1),
            col2_or_col3_change=pl.min_horizontal(pl.col('col2','col3') != pl.col('col2','col3').shift(1))
    ).fill_null(True)

    │ col1 ┆ col2 ┆ col3      ┆ col1_change ┆ col2_or_col3_change │
    │ ---  ┆ ---  ┆ ---       ┆ ---         ┆ ---                 │
    │ i64  ┆ str  ┆ str       ┆ bool        ┆ bool                │
    │ 1    ┆ br   ┆ brazil    ┆ true        ┆ true                │
    │ 1    ┆ bra  ┆ brazil    ┆ false       ┆ false               │
    │ 1    ┆ bra  ┆ brasil    ┆ false       ┆ false               │
    │ 1    ┆ col  ┆ collombia ┆ false       ┆ true                │
    │ 1    ┆ col  ┆ columbia  ┆ false       ┆ false               │
    │ 2    ┆ br   ┆ brazil    ┆ true        ┆ true                │
    │ 2    ┆ b    ┆ brazil    ┆ false       ┆ false               │
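
The reason a horizontal minimum fits here can be checked in plain Python: over booleans, taking the minimum behaves like a logical AND (true only when both columns changed), and taking the maximum behaves like a logical OR. The snippet below is only an illustration of that equivalence on ordinary Python booleans, not Polars code.

```python
# Each tuple is (col2_changed, col3_changed) for one hypothetical row.
cases = [(True, True), (True, False), (False, True), (False, False)]

for a, b in cases:
    # min over booleans is True only when both are True (logical AND).
    assert min(a, b) == (a and b)
    # max over booleans is True when at least one is True (logical OR).
    assert max(a, b) == (a or b)

# True only for the row where both col2 and col3 changed.
both_changed = [min(a, b) for a, b in cases]
print(both_changed)  # [True, False, False, False]
```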

    Now, we need to start a new group when one of these conditions is true, and for that we can use either | or .max_horizontal(). I'm using .max_horizontal() because it allows me to switch from DataFrame.fill_null() to Expr.fill_null(), which is going to be handy during the next step:

    df.with_columns(
            group_start = pl.max_horizontal(
                (pl.col('col1') != pl.col('col1').shift(1)) |
                pl.min_horizontal(pl.col('col2','col3') != pl.col('col2','col3').shift(1))
            ).fill_null(True)
    )

    │ col1 ┆ col2 ┆ col3      ┆ group_start │
    │ ---  ┆ ---  ┆ ---       ┆ ---         │
    │ i64  ┆ str  ┆ str       ┆ bool        │
    │ 1    ┆ br   ┆ brazil    ┆ true        │
    │ 1    ┆ bra  ┆ brazil    ┆ false       │
    │ 1    ┆ bra  ┆ brasil    ┆ false       │
    │ 1    ┆ col  ┆ collombia ┆ true        │
    │ 1    ┆ col  ┆ columbia  ┆ false       │
    │ 2    ┆ br   ┆ brazil    ┆ true        │
    │ 2    ┆ b    ┆ brazil    ┆ false       │

    And now, you can use Expr.cum_sum() to enumerate groups:

    df.with_columns(
            group_idx = pl.max_horizontal(
                (pl.col('col1') != pl.col('col1').shift(1)) |
                pl.min_horizontal(pl.col('col2','col3') != pl.col('col2','col3').shift(1))
            ).fill_null(True).cum_sum()
    )

    │ col1 ┆ col2 ┆ col3      ┆ group_idx │
    │ ---  ┆ ---  ┆ ---       ┆ ---       │
    │ i64  ┆ str  ┆ str       ┆ u32       │
    │ 1    ┆ br   ┆ brazil    ┆ 1         │
    │ 1    ┆ bra  ┆ brazil    ┆ 1         │
    │ 1    ┆ bra  ┆ brasil    ┆ 1         │
    │ 1    ┆ col  ┆ collombia ┆ 2         │
    │ 1    ┆ col  ┆ columbia  ┆ 2         │
    │ 2    ┆ br   ┆ brazil    ┆ 3         │
    │ 2    ┆ b    ┆ brazil    ┆ 3         │
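
The consecutive-row rule can also be cross-checked with a plain Python loop. This is a minimal sketch under the same assumptions as above (a new group starts when col1 changes, or when both col2 and col3 differ from the previous row); the function name `consecutive_groups` is made up for this example.

```python
def consecutive_groups(col1, col2, col3):
    """Return 1-based group ids for consecutive rows."""
    ids = []
    current = 0
    for i in range(len(col1)):
        starts_group = (
            i == 0  # the first row always opens a group
            or col1[i] != col1[i - 1]
            or (col2[i] != col2[i - 1] and col3[i] != col3[i - 1])
        )
        if starts_group:
            current += 1  # running count of group starts, like cum_sum
        ids.append(current)
    return ids


col1 = [1, 1, 1, 1, 1, 2, 2]
col2 = ['br', 'bra', 'bra', 'col', 'col', 'br', 'b']
col3 = ['brazil', 'brazil', 'brasil', 'collombia', 'columbia', 'brazil', 'brazil']
print(consecutive_groups(col1, col2, col3))  # [1, 1, 1, 2, 2, 3, 3]

# The 4-row illustration from this section:
print(consecutive_groups(
    [1, 1, 1, 1],
    ['bra', 'br', 'bra', 'br'],
    ['brazil', 'brasil', 'brazil', 'brazil'],
))  # [1, 2, 3, 3]
```

On the question's data both options give the same result; they only differ when matching rows are separated by non-matching ones, as in the 4-row illustration.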