python-polars

Enumerate each group


Starting with

df = pl.DataFrame({'group': [1, 1, 1, 3, 3, 3, 4, 4]})

how can I add a column that numbers the distinct groups in the 'group' column, starting from 0?

Here's what df looks like:

shape: (8, 1)
┌───────┐
│ group │
│ ---   │
│ i64   │
╞═══════╡
│ 1     │
│ 1     │
│ 1     │
│ 3     │
│ 3     │
│ 3     │
│ 4     │
│ 4     │
└───────┘

And here's my expected output:

shape: (8, 2)
┌───────┬─────────┐
│ group ┆ group_i │
│ ---   ┆ ---     │
│ i64   ┆ i64     │
╞═══════╪═════════╡
│ 1     ┆ 0       │
│ 1     ┆ 0       │
│ 1     ┆ 0       │
│ 3     ┆ 1       │
│ 3     ┆ 1       │
│ 3     ┆ 1       │
│ 4     ┆ 2       │
│ 4     ┆ 2       │
└───────┴─────────┘

Here's one way I came up with, but it feels a bit complex for such a simple task. Is there a simpler way?

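# ne_missing treats the leading null from shift() as "different", so the first row starts group 0
df.with_columns((pl.col('group').ne_missing(pl.col('group').shift()).cast(pl.Int64).cum_sum() - 1).alias('group_i'))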
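The shift / compare / cum_sum idea above can be sketched in plain Python (no Polars involved; `run_ids` is a hypothetical helper name used only for illustration). Mark 1 wherever a row differs from the previous one, counting the first row as a change since it has no predecessor, take the running total, and subtract 1 so the numbering starts at 0.

```python
from itertools import accumulate


def run_ids(values):
    """Number consecutive runs of equal values, starting at 0 (illustrative helper)."""
    # 1 where a new run starts, 0 otherwise; the first row always starts a run
    starts = [1 if i == 0 or v != values[i - 1] else 0 for i, v in enumerate(values)]
    # Running total of run starts, shifted down by 1 so the first run is 0
    return [total - 1 for total in accumulate(starts)]


print(run_ids([1, 1, 1, 3, 3, 3, 4, 4]))  # [0, 0, 0, 1, 1, 1, 2, 2]
```

Because this numbers runs rather than values, a value that reappears later (e.g. `[1, 1, 3, 1]`) gets a new ID each time it starts a new run.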

Solution

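  • The terminology comes from SQL: what you are describing is ranking your data with .rank().

    In particular, you want a "dense" ranking: equal values share the same rank, and ranks increase by 1 with no gaps.

    Since .rank() starts at 1, subtract 1 to start the numbering at 0.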

    df.with_columns(pl.col("group").alias("group_i").rank("dense") - 1)
    
    shape: (8, 2)
    ┌───────┬─────────┐
    │ group ┆ group_i │
    │ ---   ┆ ---     │
    │ i64   ┆ u32     │
    ╞═══════╪═════════╡
    │ 1     ┆ 0       │
    │ 1     ┆ 0       │
    │ 1     ┆ 0       │
    │ 3     ┆ 1       │
    │ 3     ┆ 1       │
    │ 3     ┆ 1       │
    │ 4     ┆ 2       │
    │ 4     ┆ 2       │
    └───────┴─────────┘
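One thing worth checking: a dense rank numbers groups by their sorted value, not by their order of appearance, so it matches the expected output here because 'group' is already sorted. The plain-Python sketch below (not Polars itself; `dense_rank_from_zero` and `run_ids` are hypothetical helper names) mimics what `rank("dense") - 1` computes and compares it with numbering consecutive runs, as the shift / cum_sum approach in the question does.

```python
from itertools import accumulate


def dense_rank_from_zero(values):
    """Mimic rank("dense") - 1: each distinct value gets its position in
    sorted order, so equal values share an ID and there are no gaps."""
    position = {v: i for i, v in enumerate(sorted(set(values)))}
    return [position[v] for v in values]


def run_ids(values):
    """Number consecutive runs of equal values in order of appearance."""
    starts = [1 if i == 0 or v != values[i - 1] else 0 for i, v in enumerate(values)]
    return [total - 1 for total in accumulate(starts)]


sorted_groups = [1, 1, 1, 3, 3, 3, 4, 4]
unsorted_groups = [3, 3, 1, 1, 4]

print(dense_rank_from_zero(sorted_groups))    # [0, 0, 0, 1, 1, 1, 2, 2]
print(run_ids(sorted_groups))                 # [0, 0, 0, 1, 1, 1, 2, 2]
print(dense_rank_from_zero(unsorted_groups))  # [1, 1, 0, 0, 2]
print(run_ids(unsorted_groups))               # [0, 0, 1, 1, 2]
```

If you need numbering by order of appearance rather than by value, recent Polars releases also provide `pl.col('group').rle_id()`, which assigns a 0-based ID to each run of identical values; check that your installed version includes it.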