I have a data set with multiple records for each individual - one record per time period. Where an individual is missing a record for a time period, I need to remove any later records for that individual. So given an example dataset like this:
import polars as pl
df = pl.DataFrame({
    'Id': [1, 1, 2, 2, 2, 2, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6],
    'Age': [1, 4, 1, 2, 3, 4, 1, 2, 1, 2, 3, 1, 2, 4, 2, 3, 4],
    'Value': [1, 142, 4, 73, 109, 145, 6, 72, -8, 67, 102, -1, 72, 150, 72, 111, 149],
})
df
shape: (17, 3)
┌─────┬─────┬───────┐
│ Id ┆ Age ┆ Value │
│ --- ┆ --- ┆ --- │
│ i64 ┆ i64 ┆ i64 │
╞═════╪═════╪═══════╡
│ 1 ┆ 1 ┆ 1 │
│ 1 ┆ 4 ┆ 142 │
│ 2 ┆ 1 ┆ 4 │
│ 2 ┆ 2 ┆ 73 │
│ 2 ┆ 3 ┆ 109 │
│ 2 ┆ 4 ┆ 145 │
│ 3 ┆ 1 ┆ 6 │
│ 3 ┆ 2 ┆ 72 │
│ 4 ┆ 1 ┆ -8 │
│ 4 ┆ 2 ┆ 67 │
│ 4 ┆ 3 ┆ 102 │
│ 5 ┆ 1 ┆ -1 │
│ 5 ┆ 2 ┆ 72 │
│ 5 ┆ 4 ┆ 150 │
│ 6 ┆ 2 ┆ 72 │
│ 6 ┆ 3 ┆ 111 │
│ 6 ┆ 4 ┆ 149 │
└─────┴─────┴───────┘
I need to filter it as follows (the Keep column is the intermediate flag from my attempt below):
shape: (12, 4)
┌─────┬─────┬───────┬──────┐
│ Id ┆ Age ┆ Value ┆ Keep │
│ --- ┆ --- ┆ --- ┆ --- │
│ i64 ┆ i64 ┆ i64 ┆ bool │
╞═════╪═════╪═══════╪══════╡
│ 1 ┆ 1 ┆ 1 ┆ true │
│ 2 ┆ 1 ┆ 4 ┆ true │
│ 2 ┆ 2 ┆ 73 ┆ true │
│ 2 ┆ 3 ┆ 109 ┆ true │
│ 2 ┆ 4 ┆ 145 ┆ true │
│ 3 ┆ 1 ┆ 6 ┆ true │
│ 3 ┆ 2 ┆ 72 ┆ true │
│ 4 ┆ 1 ┆ -8 ┆ true │
│ 4 ┆ 2 ┆ 67 ┆ true │
│ 4 ┆ 3 ┆ 102 ┆ true │
│ 5 ┆ 1 ┆ -1 ┆ true │
│ 5 ┆ 2 ┆ 72 ┆ true │
└─────┴─────┴───────┴──────┘
So an individual with an age record profile of 1, 3, 4 would end up with only the age-1 record. An individual like Id 6, with an age record profile of 2, 3, 4, would end up with no records after filtering.
I can achieve this using the approach below. However, when the data set contains millions of individuals, the code does not appear to run in parallel and performance is very slow: the steps prior to the final filter expression complete in ~22 seconds on a data set with 16.5 million records, while the last filter expression takes another 12.5 minutes. Is there an alternative approach that is not single-threaded, or an adjustment to the code below that achieves that?
df2 = (
    df.sort(by=["Id", "Age"])
    .with_columns(
        (pl.col("Age").diff(1).fill_null(pl.col("Age") == 1) == 1)
        .over("Id")
        .alias("Keep")
    )
    .filter(
        (pl.col("Keep").cum_prod() == 1).over("Id")
    )
)
I propose the following (revised) code:
df2 = df.filter(pl.col('Age').rank().over('Id') == pl.col('Age'))
This code yields the following result on your test dataset:
shape: (12, 3)
┌─────┬─────┬───────┐
│ Id ┆ Age ┆ Value │
│ --- ┆ --- ┆ --- │
│ i64 ┆ i64 ┆ i64 │
╞═════╪═════╪═══════╡
│ 1 ┆ 1 ┆ 1 │
│ 2 ┆ 1 ┆ 4 │
│ 2 ┆ 2 ┆ 73 │
│ 2 ┆ 3 ┆ 109 │
│ 2 ┆ 4 ┆ 145 │
│ 3 ┆ 1 ┆ 6 │
│ 3 ┆ 2 ┆ 72 │
│ 4 ┆ 1 ┆ -8 │
│ 4 ┆ 2 ┆ 67 │
│ 4 ┆ 3 ┆ 102 │
│ 5 ┆ 1 ┆ -1 │
│ 5 ┆ 2 ┆ 72 │
└─────┴─────┴───────┘
Basically, when an Age is skipped (for a particular Id), the rank of the Age falls out of step with the Age variable itself, and it remains out of step for all higher Age values for that Id.
This code has several advantages over my prior answer. It is more concise, it's far easier to follow, and best of all ... it makes excellent use of the Polars API, particularly the over window function.
Even if this code benchmarks slightly slower in the upcoming Polars release, I recommend it for the reasons above.
OK, wow, wow, wow ... I just downloaded the newly released Polars (0.13.15) and re-benchmarked the code on my machine with the 17 million records generated as in my prior answer.
The results?
And from watching the htop command while the code runs, it's clear that the newly released Polars code utilized all 64 logical cores on my machine. Massively parallel.
Impressive!