python-polars

Assign categorical column based on maximum value in group where other value is non-null


Suppose a polars DataFrame df contains a set of grouping columns group_cols and two other columns, col1 and col2.

I want to add a new column called category that takes the value 'primary' or 'secondary'. There should be exactly one 'primary' per group in group_cols, placed on the row where col1 is maximal among the rows where col2 is not null. (If col2 is null for every row of the group, it should simply go on the row where col1 is maximal.)

For example, if one group contains these 4 rows:

  • col1 = 2, col2 = None
  • col1 = 1, col2 = 'not null'
  • col1 = 1, col2 = 'not null also'
  • col1 = 0, col2 = 'not null either'

then either the 2nd or the 3rd row of the group (it does not matter which, but only one of them) should get 'primary' in the category column, and all other rows should get 'secondary'.
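To pin down the rule, here is a plain-Python sketch for a single group. The helper name primary_index is hypothetical (it is not from the question or from any library). It breaks ties by taking the first maximal row, which the question allows, and it skips rows whose col1 is missing, as pandas idxmax and polars max do by default.

```python
from typing import Optional, Sequence


def primary_index(
    col1: Sequence[Optional[float]], col2: Sequence[Optional[str]]
) -> Optional[int]:
    """Position of the 'primary' row within one group (hypothetical helper).

    Candidates are the rows with a non-null col2; if there are none, every row
    is a candidate. Among the candidates, the first row with the largest
    non-null col1 wins. Returns None if no candidate has a col1 value.
    """
    rows = list(range(len(col1)))
    # Prefer rows with non-null col2; fall back to all rows if col2 is all null
    candidates = [i for i in rows if col2[i] is not None] or rows
    # Skip rows whose col1 itself is missing
    candidates = [i for i in candidates if col1[i] is not None]
    if not candidates:
        return None
    # max() keeps the first maximal element, so ties go to the earliest row
    return max(candidates, key=lambda i: col1[i])


# The four-row group from the question: rows 1 and 2 tie, the first one wins
print(primary_index([2, 1, 1, 0], [None, "not null", "not null also", "not null either"]))
```

Every other row in the group would then be 'secondary'.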

I can do this in pandas quite easily as follows:

import pandas as pd

def assign_category(group):
    # Start with every row in the group marked 'secondary'
    category_column = pd.Series('secondary', index=group.index)
    has_col2 = group['col2'].notnull()
    if has_col2.any():
        # Index label of the (first) largest col1 among rows with non-null col2
        max_col1_row = group.loc[has_col2, 'col1'].idxmax()
    else:
        # col2 is null throughout the group: use the (first) largest col1 overall
        max_col1_row = group['col1'].idxmax()
    category_column.loc[max_col1_row] = 'primary'
    return category_column

# group_keys=False keeps the original row labels, so the result aligns with df
# even if the rows of a group are not stored contiguously
df['category'] = df.groupby('group_cols', group_keys=False)[['col1', 'col2']].apply(assign_category)

and here is a reproducible example:

df = pd.DataFrame({
    'group_cols': ['A', 'A', 'A', 'A', 'B', 'B', 'C', 'C', 'C', 'C', 'C'],
    'col1': [2, 1, 1, 0, 3, 2, 3, None, 4, 4, 1],
    'col2': [None, 'not null', 'also not null', 'not null either', 'not null', 'not null', None, None, None, None, None]
})

becomes

pd.DataFrame({'group_cols': ['A', 'A', 'A', 'A', 'B', 'B', 'C', 'C', 'C', 'C', 'C'],
 'col1': [2.0, 1.0, 1.0, 0.0, 3.0, 2.0, 3.0, nan, 4.0, 4.0, 1.0],
 'col2': [None,
  'not null',
  'also not null',
  'not null either',
  'not null',
  'not null',
  None,
  None,
  None,
  None,
  None],
 'category': ['secondary',
  'primary',
  'secondary',
  'secondary',
  'primary',
  'secondary',
  'secondary',
  'secondary',
  'primary',
  'secondary',
  'secondary']})

But I am struggling to even start replicating this kind of logic in polars, let alone without making use of the discouraged map_groups.


Solution

  • You can use the following (thanks to jqurious for pointing out how this can be done with a single over):

    import polars as pl

    df = pl.DataFrame(
        {
            "group_cols": ["A", "A", "A", "A", "B", "B", "C", "C", "C", "C", "C"],
            "col1": [2, 1, 1, 0, 3, 2, 3, None, 4, 4, 1],
            "col2": [
                None,
                "not null",
                "also not null",
                "not null either",
                "not null",
                "not null",
                None,
                None,
                None,
                None,
                None,
            ],
        }
    )

    # Position of each row within its group (0, 1, 2, ...)
    row_idx = pl.int_range(pl.len())

    df = df.with_columns(
        category=pl.when(pl.col("col2").is_null().all())
        # col2 is entirely null in this group: mark the first row with the largest col1
        .then(
            pl.when(row_idx == pl.col("col1").arg_max())
            .then(pl.lit("primary"))
            .otherwise(pl.lit("secondary"))
        )
        # Otherwise: mark the first row whose col2 is non-null AND whose col1 equals
        # the largest col1 among the non-null-col2 rows
        .when(
            row_idx
            == pl.arg_where(
                (pl.col("col1") == pl.col("col1").filter(pl.col("col2").is_not_null()).max())
                & pl.col("col2").is_not_null()
            ).first()
        )
        .then(pl.lit("primary"))
        .otherwise(pl.lit("secondary"))
        .over(["group_cols"])
    )

    with pl.Config(tbl_rows=-1):
        print(df)
    

    This first handles groups where col2 is entirely null, marking the first row with the maximal col1. For all other groups, it marks the first row that has the maximal col1 among the rows where col2 is non-null.

    shape: (11, 4)
    ┌────────────┬──────┬─────────────────┬───────────┐
    │ group_cols ┆ col1 ┆ col2            ┆ category  │
    │ ---        ┆ ---  ┆ ---             ┆ ---       │
    │ str        ┆ i64  ┆ str             ┆ str       │
    ╞════════════╪══════╪═════════════════╪═══════════╡
    │ A          ┆ 2    ┆ null            ┆ secondary │
    │ A          ┆ 1    ┆ not null        ┆ primary   │
    │ A          ┆ 1    ┆ also not null   ┆ secondary │
    │ A          ┆ 0    ┆ not null either ┆ secondary │
    │ B          ┆ 3    ┆ not null        ┆ primary   │
    │ B          ┆ 2    ┆ not null        ┆ secondary │
    │ C          ┆ 3    ┆ null            ┆ secondary │
    │ C          ┆ null ┆ null            ┆ secondary │
    │ C          ┆ 4    ┆ null            ┆ primary   │
    │ C          ┆ 4    ┆ null            ┆ secondary │
    │ C          ┆ 1    ┆ null            ┆ secondary │
    └────────────┴──────┴─────────────────┴───────────┘