Select the first and last row per group in Polars dataframe


I have a Polars dataframe and would like to select the first and last row of each group. Here is a simple example that selects the first row per group:

import polars as pl

df = pl.DataFrame(
    {
        "a": [1, 2, 2, 3, 4, 5],
        "b": [0.5, 0.5, 4, 10, 14, 13],
        "c": [True, True, True, False, False, True],
        "d": ["Apple", "Apple", "Apple", "Banana", "Banana", "Banana"],
    }
)
result = df.group_by("d", maintain_order=True).first()
print(result)

Output:

shape: (2, 4)
┌────────┬─────┬──────┬───────┐
│ d      ┆ a   ┆ b    ┆ c     │
│ ---    ┆ --- ┆ ---  ┆ ---   │
│ str    ┆ i64 ┆ f64  ┆ bool  │
╞════════╪═════╪══════╪═══════╡
│ Apple  ┆ 1   ┆ 0.5  ┆ true  │
│ Banana ┆ 3   ┆ 10.0 ┆ false │
└────────┴─────┴──────┴───────┘

This works well, and .last does the same for the last row. But how can I combine the two in a single group_by?


Solution

  • As columns

    You could use agg; you will have to add a suffix (or prefix) to differentiate the column names:

    result = (df.group_by('d', maintain_order=True)
                .agg(pl.all().first().name.suffix('_first'),
                     pl.all().last().name.suffix('_last'))
             )
    

    Output:

    ┌────────┬─────────┬─────────┬─────────┬────────┬────────┬────────┐
    │ d      ┆ a_first ┆ b_first ┆ c_first ┆ a_last ┆ b_last ┆ c_last │
    │ ---    ┆ ---     ┆ ---     ┆ ---     ┆ ---    ┆ ---    ┆ ---    │
    │ str    ┆ i64     ┆ f64     ┆ bool    ┆ i64    ┆ f64    ┆ bool   │
    ╞════════╪═════════╪═════════╪═════════╪════════╪════════╪════════╡
    │ Apple  ┆ 1       ┆ 0.5     ┆ true    ┆ 2      ┆ 4.0    ┆ true   │
    │ Banana ┆ 3       ┆ 10.0    ┆ false   ┆ 5      ┆ 13.0   ┆ true   │
    └────────┴─────────┴─────────┴─────────┴────────┴────────┴────────┘
    

  • As rows

    If you want the first and last rows as separate rows, you need to concat the two results:

    g = df.group_by('d', maintain_order=True)
    
    result = pl.concat([g.first(), g.last()]).sort(by='d', maintain_order=True)
    

    Output:

    ┌────────┬─────┬──────┬───────┐
    │ d      ┆ a   ┆ b    ┆ c     │
    │ ---    ┆ --- ┆ ---  ┆ ---   │
    │ str    ┆ i64 ┆ f64  ┆ bool  │
    ╞════════╪═════╪══════╪═══════╡
    │ Apple  ┆ 1   ┆ 0.5  ┆ true  │
    │ Apple  ┆ 2   ┆ 4.0  ┆ true  │
    │ Banana ┆ 3   ┆ 10.0 ┆ false │
    │ Banana ┆ 5   ┆ 13.0 ┆ true  │
    └────────┴─────┴──────┴───────┘
    

    Alternatively, use filter with int_range + over, which keeps the original row and column order:

    result = df.filter((pl.int_range(pl.len()).over('d') == 0)
                      |(pl.int_range(pl.len(), 0, -1).over('d') == 1)
                      )
    

    Output:

    ┌─────┬──────┬───────┬────────┐
    │ a   ┆ b    ┆ c     ┆ d      │
    │ --- ┆ ---  ┆ ---   ┆ ---    │
    │ i64 ┆ f64  ┆ bool  ┆ str    │
    ╞═════╪══════╪═══════╪════════╡
    │ 1   ┆ 0.5  ┆ true  ┆ Apple  │
    │ 2   ┆ 4.0  ┆ true  ┆ Apple  │
    │ 3   ┆ 10.0 ┆ false ┆ Banana │
    │ 5   ┆ 13.0 ┆ true  ┆ Banana │
    └─────┴──────┴───────┴────────┘