Return two closest rows above and below a target value in Polars


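I'm trying to find the most elegant way in Polars to return the two rows that bracket a specific target: the first row above it and the first row below it. In terms of the diff column computed below, that is the row with the smallest diff > 0 and the row with the largest diff < 0.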

import polars as pl

data = {
    "strike": [5,10,15,20,25,30],
    "target": [16] * 6,
}

df = (pl.DataFrame(data)
        .with_columns(
            diff = pl.col('strike') - pl.col('target')))

shape: (6, 3)
┌────────┬────────┬──────┐
│ strike ┆ target ┆ diff │
│ ---    ┆ ---    ┆ ---  │
│ i64    ┆ i64    ┆ i64  │
╞════════╪════════╪══════╡
│ 5      ┆ 16     ┆ -11  │
│ 10     ┆ 16     ┆ -6   │
│ 15     ┆ 16     ┆ -1   │
│ 20     ┆ 16     ┆ 4    │
│ 25     ┆ 16     ┆ 9    │
│ 30     ┆ 16     ┆ 14   │
└────────┴────────┴──────┘

This is what I'm trying to arrive at.

shape: (2, 3)
┌────────┬────────┬──────┐
│ strike ┆ target ┆ diff │
│ ---    ┆ ---    ┆ ---  │
│ i64    ┆ i64    ┆ i64  │
╞════════╪════════╪══════╡
│ 15     ┆ 16     ┆ -1   │
│ 20     ┆ 16     ┆ 4    │
└────────┴────────┴──────┘

I can do this with two separate filter operations, but I can't work out how to combine them into one, and I'd like to avoid having to vstack the two individual results back into a single DataFrame if possible.

df1 = (pl.DataFrame(data)
        .with_columns(
            diff = pl.col('strike') - pl.col('target'))
        .filter((pl.col('diff') > 0)).min())

shape: (1, 3)
┌────────┬────────┬──────┐
│ strike ┆ target ┆ diff │
│ ---    ┆ ---    ┆ ---  │
│ i64    ┆ i64    ┆ i64  │
╞════════╪════════╪══════╡
│ 20     ┆ 16     ┆ 4    │
└────────┴────────┴──────┘

df2 = (pl.DataFrame(data)
        .with_columns(
            diff = pl.col('strike') - pl.col('target'))
        .filter((pl.col('diff') < 0)).max())

shape: (1, 3)
┌────────┬────────┬──────┐
│ strike ┆ target ┆ diff │
│ ---    ┆ ---    ┆ ---  │
│ i64    ┆ i64    ┆ i64  │
╞════════╪════════╪══════╡
│ 15     ┆ 16     ┆ -1   │
└────────┴────────┴──────┘


Solution

    import polars as pl
    
    data = {"strike": [5,10,15,20,25,30]}
    
    target = 16
    diff = pl.col("strike") - target
    
    (
        pl.DataFrame(data)
        # diff and target can be added as columns, but this is not actually needed
        # .with_columns(target=target, diff=diff)
        # filter for where the diff equals the min above 0, or the max below 0
        .filter(
            (diff == diff.filter(diff > 0).min()) |
            (diff == diff.filter(diff < 0).max())
        )
    )
    # shape: (2, 1)
    # ┌────────┐
    # │ strike │
    # │ ---    │
    # │ i64    │
    # ╞════════╡
    # │ 15     │
    # │ 20     │
    # └────────┘