
How to filter a lazy dataframe based on the n-to-last element of a column?


My input is this:

import polars as pl

ldf = pl.LazyFrame({'CATEGORY': ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I'],
 'DATE': ['05/01/2023', '05/01/2023', '05/01/2023', '05/01/2023', '16/05/2023',
          '16/05/2023', '14/05/2023', '14/05/2023', '14/05/2023']})

print(ldf.collect())

shape: (9, 2)
┌──────────┬────────────┐
│ CATEGORY ┆ DATE       │
│ ---      ┆ ---        │
│ str      ┆ str        │
╞══════════╪════════════╡
│ A        ┆ 05/01/2023 │
│ B        ┆ 05/01/2023 │
│ C        ┆ 05/01/2023 │
│ D        ┆ 05/01/2023 │
│ E        ┆ 16/05/2023 │
│ F        ┆ 16/05/2023 │
│ G        ┆ 14/05/2023 │
│ H        ┆ 14/05/2023 │
│ I        ┆ 14/05/2023 │
└──────────┴────────────┘

I want to collect the data based on the n-th-to-last element (say the second-to-last, here 16/05/2023). I can't hardcode it because the parquet files I receive can contain different dates, and I'm only interested in the rows that correspond to the second-to-last date.

To be more precise: I need to target the second-to-last group in the DATE column.

To filter based on the last row (which is guaranteed to belong to the last group), I can use pl.last('DATE') as below:

print(
    ldf
        .filter(pl.col('DATE') == pl.last('DATE'))
        .collect()
)

shape: (3, 2)
┌──────────┬────────────┐
│ CATEGORY ┆ DATE       │
│ ---      ┆ ---        │
│ str      ┆ str        │
╞══════════╪════════════╡
│ G        ┆ 14/05/2023 │
│ H        ┆ 14/05/2023 │
│ I        ┆ 14/05/2023 │
└──────────┴────────────┘

My expected output is this:

shape: (2, 2)
┌──────────┬────────────┐
│ CATEGORY ┆ DATE       │
│ ---      ┆ ---        │
│ str      ┆ str        │
╞══════════╪════════════╡
│ E        ┆ 16/05/2023 │
│ F        ┆ 16/05/2023 │
└──────────┴────────────┘

Can you tell me if it's possible to do that before calling collect?


Solution

  • To get the unique dates in the column while preserving their order, unique does exactly that when given the keyword argument maintain_order=True.

    From there, we just need to extract the element at index -2. Unfortunately, Expr.get does not support negative indexing at this time, but we can work around this by collecting the values into a list with implode() and then using list.get, which does support negative indexing.

    Overall:

    ldf.filter(pl.col('DATE') ==
        pl.col('DATE')
        .unique(maintain_order=True)
        .implode()
        .list.get(-2)
    ).collect()
    
    shape: (2, 2)
    ┌──────────┬────────────┐
    │ CATEGORY ┆ DATE       │
    │ ---      ┆ ---        │
    │ str      ┆ str        │
    ╞══════════╪════════════╡
    │ E        ┆ 16/05/2023 │
    │ F        ┆ 16/05/2023 │
    └──────────┴────────────┘
    

    Addendum: this assumes each date appears only in one contiguous block. If that is not true, some kind of group_by, followed by transforming the result back into the original shape, is probably needed.

    While the solution above works, it scans the whole DataFrame to compute all unique values when ultimately we only want the n-th one. For the second-to-last value, it feels like some technique starting from the bottom of the DataFrame (or from the top, if n were small and counted from the top) would scale better.