
Filtering from index and comparing row value with all values in column


Starting with this DataFrame:

df_1 = pl.DataFrame({
    'name': ['Alpha', 'Alpha', 'Alpha', 'Alpha', 'Alpha'],
    'index': [0, 3, 4, 7, 9],
    'limit': [12, 18, 11, 5, 9],
    'price': [10, 15, 12, 8, 11]
})

┌───────┬───────┬───────┬───────┐
│ name  ┆ index ┆ limit ┆ price │
│ ---   ┆   --- ┆   --- ┆   --- │
│ str   ┆   i64 ┆   i64 ┆   i64 │
╞═══════╪═══════╪═══════╪═══════╡
│ Alpha ┆     0 ┆    12 ┆    10 │
│ Alpha ┆     3 ┆    18 ┆    15 │
│ Alpha ┆     4 ┆    11 ┆    12 │
│ Alpha ┆     7 ┆     5 ┆     8 │
│ Alpha ┆     9 ┆     9 ┆    11 │
└───────┴───────┴───────┴───────┘

I need to add a new column that gives, for each row, the smallest index (greater than the current one) at which the price is greater than or equal to the current row's limit.

For the example above, the expected output is:

┌───────┬───────┬───────┬───────┬───────────┐
│ name  ┆ index ┆ limit ┆ price ┆ min_index │
│ ---   ┆   --- ┆   --- ┆   --- ┆       --- │
│ str   ┆   i64 ┆   i64 ┆   i64 ┆       i64 │
╞═══════╪═══════╪═══════╪═══════╪═══════════╡
│ Alpha ┆     0 ┆    12 ┆    10 ┆         3 │
│ Alpha ┆     3 ┆    18 ┆    15 ┆      null │
│ Alpha ┆     4 ┆    11 ┆    12 ┆         9 │
│ Alpha ┆     7 ┆     5 ┆     8 ┆         9 │
│ Alpha ┆     9 ┆     9 ┆    11 ┆      null │
└───────┴───────┴───────┴───────┴───────────┘

Explaining the "min_index" column results:

  • 1st row, where the limit is 12: from the 2nd row onwards, the minimum index whose price is greater than or equal to 12 is 3.
  • 2nd row, where the limit is 18: from the 3rd row onwards, there is no index whose price is greater than or equal to 18.
  • 3rd row, where the limit is 11: from the 4th row onwards, the minimum index whose price is greater than or equal to 11 is 9.
  • 4th row, where the limit is 5: from the 5th row onwards, the minimum index whose price is greater than or equal to 5 is 9.
  • 5th row, where the limit is 9: as this is the last row, there is no later index to compare against, so the result is null.
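
For reference, the rule in the bullets above can be written as a brute-force check in plain Python. This is only a sketch to pin down the expected values, not a Polars solution; the helper name `min_index_reference` is made up for illustration.

```python
def min_index_reference(index, limit, price):
    """For each row, return the smallest later index whose price >= that row's limit."""
    result = []
    for cur_index, cur_limit in zip(index, limit):
        # Collect every later index whose price reaches the current row's limit.
        candidates = [
            other_index
            for other_index, other_price in zip(index, price)
            if other_index > cur_index and other_price >= cur_limit
        ]
        # No candidate means there is no match, which corresponds to null.
        result.append(min(candidates) if candidates else None)
    return result


index = [0, 3, 4, 7, 9]
limit = [12, 18, 11, 5, 9]
price = [10, 15, 12, 8, 11]

print(min_index_reference(index, limit, price))  # [3, None, 9, 9, None]
```

Any Polars approach should reproduce these values in its "min_index" column, with None appearing as null.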

My solution is shown below, but what would be a neat Polars way of doing it? It takes 8 steps, and I'm sure there is a more efficient way.

# Import Polars.
import polars as pl

# Create a sample DataFrame.
df_1 = pl.DataFrame({
    'name': ['Alpha', 'Alpha', 'Alpha', 'Alpha', 'Alpha'],
    'index': [0, 3, 4, 7, 9],
    'limit': [12, 18, 11, 5, 9],
    'price': [10, 15, 12, 8, 11]
})

# Group by name, so that each column's values are collected into a single list per group.
df_2 = df_1.group_by('name').agg(pl.all())

# Join the lists back onto the original DataFrame.
df_3 = df_1.join(
    other=df_2,
    on='name',
    suffix='_list'
)

# Explode the dataframe to long format by exploding the given columns.
df_3 = df_3.explode([
    'index_list',
    'limit_list',
    'price_list',
])

# Filter the DataFrame for the condition we want.
df_3 = df_3.filter(
    (pl.col('index_list') > pl.col('index')) &
    (pl.col('price_list') >= pl.col('limit'))
)

# Get the minimum index over the index column.
df_3 = df_3.with_columns(
    pl.col('index_list').min().over('index').alias('min_index')
)

# Select only the relevant columns and drop duplicates.
df_3 = df_3.select(
    pl.col(['index', 'min_index'])
).unique()

# Finally join the result.
df_final = df_1.join(
    other=df_3,
    on='index',
    how='left'
)

print(df_final)

Solution

  • Option 1: df.join_where (experimental)

    out = (
        df_1.join(
            df_1
            .join_where(
                df_1.select('index', 'price'),
                pl.col('index_right') > pl.col('index'),
                pl.col('price_right') >= pl.col('limit')
            )
            .group_by('index')
            .agg(
                pl.col('index_right').min().alias('min_index')
                ),
            on='index',
            how='left'
        )
    )
    

    Output:

    shape: (5, 5)
    ┌───────┬───────┬───────┬───────┬───────────┐
    │ name  ┆ index ┆ limit ┆ price ┆ min_index │
    │ ---   ┆ ---   ┆ ---   ┆ ---   ┆ ---       │
    │ str   ┆ i64   ┆ i64   ┆ i64   ┆ i64       │
    ╞═══════╪═══════╪═══════╪═══════╪═══════════╡
    │ Alpha ┆ 0     ┆ 12    ┆ 10    ┆ 3         │
    │ Alpha ┆ 3     ┆ 18    ┆ 15    ┆ null      │
    │ Alpha ┆ 4     ┆ 11    ┆ 12    ┆ 9         │
    │ Alpha ┆ 7     ┆ 5     ┆ 8     ┆ 9         │
    │ Alpha ┆ 9     ┆ 9     ┆ 11    ┆ null      │
    └───────┴───────┴───────┴───────┴───────────┘
    

    Explanation / Intermediates

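    • Use df.join_where with df_1.select('index', 'price') as other (the right side does not need 'limit') and pass the two filter conditions as predicates. Columns from other whose names clash get the '_right' suffix.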
    # df_1.join_where(...)
    
    shape: (4, 6)
    ┌───────┬───────┬───────┬───────┬─────────────┬─────────────┐
    │ name  ┆ index ┆ limit ┆ price ┆ index_right ┆ price_right │
    │ ---   ┆ ---   ┆ ---   ┆ ---   ┆ ---         ┆ ---         │
    │ str   ┆ i64   ┆ i64   ┆ i64   ┆ i64         ┆ i64         │
    ╞═══════╪═══════╪═══════╪═══════╪═════════════╪═════════════╡
    │ Alpha ┆ 0     ┆ 12    ┆ 10    ┆ 3           ┆ 15          │
    │ Alpha ┆ 0     ┆ 12    ┆ 10    ┆ 4           ┆ 12          │
    │ Alpha ┆ 4     ┆ 11    ┆ 12    ┆ 9           ┆ 11          │
    │ Alpha ┆ 7     ┆ 5     ┆ 8     ┆ 9           ┆ 11          │
    └───────┴───────┴───────┴───────┴─────────────┴─────────────┘
    
    # df_1.join_where(...).group_by('index').agg(...)
    
    shape: (3, 2)
    ┌───────┬───────────┐
    │ index ┆ min_index │
    │ ---   ┆ ---       │
    │ i64   ┆ i64       │
    ╞═══════╪═══════════╡
    │ 0     ┆ 3         │
    │ 7     ┆ 9         │
    │ 4     ┆ 9         │
    └───────┴───────────┘
    
    • Add the result to df_1 with a left join, so rows without a match get null.

  • Option 2: df.join with how='cross' + df.filter

    (Adding this option because df.join_where is experimental. It is more expensive, though, since the cross join materializes every pair of rows before the filter is applied.)

    out2 = (
        df_1.join(
            df_1
            .join(df_1.select('index', 'price'), how='cross')
            .filter(
                pl.col('index_right') > pl.col('index'),
                pl.col('price_right') >= pl.col('limit')
            )
            .group_by('index')
            .agg(
                pl.col('index_right').min().alias('min_index')
            ),
            on='index',
            how='left'
        )
    )
    
    out2.equals(out)
    # True
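
To get a feel for the extra cost of Option 2, the sketch below counts the row pairs in plain Python rather than Polars: a cross join builds all n × n pairs before the filter discards most of them. The variable names are illustrative only, and the data is the sample from the question.

```python
from itertools import product

index = [0, 3, 4, 7, 9]
limit = [12, 18, 11, 5, 9]
price = [10, 15, 12, 8, 11]

# A cross join pairs every left row with every right row: n * n candidates.
all_pairs = list(product(range(len(index)), repeat=2))

# Only pairs that satisfy both predicates survive the filter.
kept_pairs = [
    (index[left], index[right])
    for left, right in all_pairs
    if index[right] > index[left] and price[right] >= limit[left]
]

print(len(all_pairs))  # 25
print(kept_pairs)      # [(0, 3), (0, 4), (4, 9), (7, 9)]
```

The four surviving pairs match the (index, index_right) columns of the intermediate join_where result shown under Option 1. With 5 rows the 25 candidate pairs are harmless, but the count grows quadratically with the number of rows.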