Suppose that a polars DataFrame df contains a set of columns group_cols, and then 2 other columns col1, col2.
I want to add a new column to this called category which takes values either 'primary' or 'secondary'. There should be exactly one occurrence of 'primary' per group in group_cols, and it should be chosen where col1 is maximal BUT col2 is not null. (But if col2 is always null, then it should simply be chosen where col1 is maximal.)
So e.g. if we have in the same group 4 rows with values:
col1 = 2, col2 = None
col1 = 1, col2 = 'not null'
col1 = 1, col2 = 'not null also'
col1 = 0, col2 = 'not null either'
then either the 2nd or 3rd row in the group (it does not matter which, but only one of them) should have its category column value as 'primary' and the rest should take the value 'secondary' in this column.
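To pin down the rule independently of any DataFrame library, here is a minimal plain-Python sketch of the selection for a single group. The helper name pick_primary and the (col1, col2) tuple layout are hypothetical and only for illustration; the sketch also assumes that rows with a missing col1 are never chosen and that ties go to the first matching row.

```python
def pick_primary(rows):
    """Return the index of the 'primary' row in one group.

    rows is a list of (col1, col2) tuples (hypothetical layout).
    Returns None if no row has a usable col1 value.
    """
    # Prefer rows whose col2 is present; fall back to all rows otherwise.
    with_col2 = [i for i, (_, c2) in enumerate(rows) if c2 is not None]
    candidates = with_col2 or list(range(len(rows)))
    # Assumption: a missing col1 can never be the maximum.
    candidates = [i for i in candidates if rows[i][0] is not None]
    # max() returns the first maximal element, so ties go to the earliest row.
    return max(candidates, key=lambda i: rows[i][0], default=None)


group = [(2, None), (1, "not null"), (1, "not null also"), (0, "not null either")]
primary = pick_primary(group)
categories = ["primary" if i == primary else "secondary" for i in range(len(group))]
```

For the four rows above, the first row is skipped despite having the largest col1, because its col2 is None.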
I can do this with pandas quite easily by the following:
def assign_category(group):
    if group['col2'].notnull().any():
        max_col1_row = group.loc[group['col2'].notnull(), 'col1'].idxmax()
        category_column = pd.Series('secondary', index=group.index)
        if max_col1_row is not None:
            category_column[max_col1_row] = 'primary'
    else:
        max_col1_row = group['col1'].idxmax()
        category_column = pd.Series('secondary', index=group.index)
        category_column[max_col1_row] = 'primary'
    return category_column

df['category'] = df.groupby('group_cols').apply(assign_category).reset_index(drop=True)
and here is a reproducible example:
df = pd.DataFrame({
    'group_cols': ['A', 'A', 'A', 'A', 'B', 'B', 'C', 'C', 'C', 'C', 'C'],
    'col1': [2, 1, 1, 0, 3, 2, 3, None, 4, 4, 1],
    'col2': [None, 'not null', 'also not null', 'not null either', 'not null', 'not null', None, None, None, None, None]
})
becomes
pd.DataFrame({
    'group_cols': ['A', 'A', 'A', 'A', 'B', 'B', 'C', 'C', 'C', 'C', 'C'],
    'col1': [2.0, 1.0, 1.0, 0.0, 3.0, 2.0, 3.0, nan, 4.0, 4.0, 1.0],
    'col2': [None, 'not null', 'also not null', 'not null either', 'not null', 'not null', None, None, None, None, None],
    'category': ['secondary', 'primary', 'secondary', 'secondary', 'primary', 'secondary', 'secondary', 'secondary', 'primary', 'secondary', 'secondary']
})
But I am struggling to even start replicating this kind of logic in polars, let alone without making use of the discouraged map_groups.
You can use the following (thanks to jqurious for pointing out how this can be done with a single over):
# %%
import polars as pl
from IPython.display import display

df = pl.DataFrame(
    {
        "group_cols": ["A", "A", "A", "A", "B", "B", "C", "C", "C", "C", "C"],
        "col1": [2, 1, 1, 0, 3, 2, 3, None, 4, 4, 1],
        "col2": [
            None,
            "not null",
            "also not null",
            "not null either",
            "not null",
            "not null",
            None,
            None,
            None,
            None,
            None,
        ],
    }
)

df = df.with_columns(
    # Case 1: col2 is null throughout the group, so take the first row with the maximal col1.
    category=pl.when(pl.col("col2").is_null().all())
    .then(
        pl.when(pl.int_range(pl.len()) == pl.col("col1").arg_max())
        .then(pl.lit("primary"))
        .otherwise(pl.lit("secondary"))
    )
    # Case 2: take the first row with a non-null col2 whose col1 equals the
    # maximum computed over the non-null col2 rows. The is_not_null() guard
    # keeps a null-col2 row with the same col1 value from being picked.
    .when(
        pl.int_range(pl.len())
        == pl.arg_where(
            pl.col("col2").is_not_null()
            & (pl.col("col1") == pl.col("col1").filter(pl.col("col2").is_not_null()).max())
        ).first()
    )
    .then(pl.lit("primary"))
    .otherwise(pl.lit("secondary"))
    .over(["group_cols"])
)

with pl.Config(tbl_rows=-1):
    display(df)
This first handles the groups where col2 is entirely null, choosing the first row with the maximal col1. For the other groups, it takes the first row with the maximal col1 among the rows where col2 is non-null.
shape: (11, 4)
┌────────────┬──────┬─────────────────┬───────────┐
│ group_cols ┆ col1 ┆ col2            ┆ category  │
│ ---        ┆ ---  ┆ ---             ┆ ---       │
│ str        ┆ i64  ┆ str             ┆ str       │
╞════════════╪══════╪═════════════════╪═══════════╡
│ A          ┆ 2    ┆ null            ┆ secondary │
│ A          ┆ 1    ┆ not null        ┆ primary   │
│ A          ┆ 1    ┆ also not null   ┆ secondary │
│ A          ┆ 0    ┆ not null either ┆ secondary │
│ B          ┆ 3    ┆ not null        ┆ primary   │
│ B          ┆ 2    ┆ not null        ┆ secondary │
│ C          ┆ 3    ┆ null            ┆ secondary │
│ C          ┆ null ┆ null            ┆ secondary │
│ C          ┆ 4    ┆ null            ┆ primary   │
│ C          ┆ 4    ┆ null            ┆ secondary │
│ C          ┆ 1    ┆ null            ┆ secondary │
└────────────┴──────┴─────────────────┴───────────┘