Starting with this DataFrame:
df_1 = pl.DataFrame({
    'name': ['Alpha', 'Alpha', 'Alpha', 'Alpha', 'Alpha'],
    'index': [0, 3, 4, 7, 9],
    'limit': [12, 18, 11, 5, 9],
    'price': [10, 15, 12, 8, 11]
})
┌───────┬───────┬───────┬───────┐
│ name ┆ index ┆ limit ┆ price │
│ --- ┆ --- ┆ --- ┆ --- │
│ str ┆ i64 ┆ i64 ┆ i64 │
╞═══════╪═══════╪═══════╪═══════╡
│ Alpha ┆ 0 ┆ 12 ┆ 10 │
│ Alpha ┆ 3 ┆ 18 ┆ 15 │
│ Alpha ┆ 4 ┆ 11 ┆ 12 │
│ Alpha ┆ 7 ┆ 5 ┆ 8 │
│ Alpha ┆ 9 ┆ 9 ┆ 11 │
└───────┴───────┴───────┴───────┘
I need to add a new column that tells me, for each row, the first later index (greater than the current one) at which the price is equal to or higher than the current row's limit.
For the example above, the expected output is:
┌───────┬───────┬───────┬───────┬───────────┐
│ name ┆ index ┆ limit ┆ price ┆ min_index │
│ --- ┆ --- ┆ --- ┆ --- ┆ --- │
│ str ┆ i64 ┆ i64 ┆ i64 ┆ i64 │
╞═══════╪═══════╪═══════╪═══════╪═══════════╡
│ Alpha ┆ 0 ┆ 12 ┆ 10 ┆ 3 │
│ Alpha ┆ 3 ┆ 18 ┆ 15 ┆ null │
│ Alpha ┆ 4 ┆ 11 ┆ 12 ┆ 9 │
│ Alpha ┆ 7 ┆ 5 ┆ 8 ┆ 9 │
│ Alpha ┆ 9 ┆ 9 ┆ 11 ┆ null │
└───────┴───────┴───────┴───────┴───────────┘
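Explaining the "min_index" column results:
- index 0 (limit 12): the first later row with price >= 12 is index 3 (price 15), so min_index is 3.
- index 3 (limit 18): no later row has price >= 18, so min_index is null.
- index 4 (limit 11): index 7 has price 8 (too low); index 9 has price 11, so min_index is 9.
- index 7 (limit 5): the only later row, index 9, has price 11, so min_index is 9.
- index 9 (limit 9): there are no later rows, so min_index is null.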
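As a cross-check of the rule itself, here is a minimal plain-Python sketch (no Polars involved) that applies it row by row to the sample data. The function name `min_index_reference` is made up for illustration; it is a brute-force reference, not a proposed solution.

```python
# Sample data copied from the DataFrame above.
index = [0, 3, 4, 7, 9]
limit = [12, 18, 11, 5, 9]
price = [10, 15, 12, 8, 11]


def min_index_reference(index, limit, price):
    """For each row, return the smallest later index whose price >= this row's limit."""
    result = []
    for cur_index, cur_limit in zip(index, limit):
        # Collect every later index whose price meets the current limit.
        candidates = [
            other_index
            for other_index, other_price in zip(index, price)
            if other_index > cur_index and other_price >= cur_limit
        ]
        # No candidate means there is nothing to report (null in Polars).
        result.append(min(candidates) if candidates else None)
    return result


print(min_index_reference(index, limit, price))
# [3, None, 9, 9, None]
```

The printed list matches the expected "min_index" column, with None standing in for null.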
My solution is shown below, but what would be a neat Polars way of doing it? I was able to solve it in 8 steps, but I'm sure there is a more efficient way of doing it.
# Import Polars.
import polars as pl
# Create a sample DataFrame.
df_1 = pl.DataFrame({
    'name': ['Alpha', 'Alpha', 'Alpha', 'Alpha', 'Alpha'],
    'index': [0, 3, 4, 7, 9],
    'limit': [12, 18, 11, 5, 9],
    'price': [10, 15, 12, 8, 11]
})
# Group by name, so that we can collect all rows' values into a single list per column.
df_2 = df_1.group_by('name').agg(pl.all())
# Join the lists back onto the original DataFrame.
df_3 = df_1.join(
    other=df_2,
    on='name',
    suffix='_list'
)
# Explode the list columns so that every row is paired with every other row.
df_3 = df_3.explode([
    'index_list',
    'limit_list',
    'price_list',
])
# Filter the DataFrame for the condition we want.
df_3 = df_3.filter(
    (pl.col('index_list') > pl.col('index')) &
    (pl.col('price_list') >= pl.col('limit'))
)
# Get the minimum matching index for each original index.
df_3 = df_3.with_columns(
    pl.col('index_list').min().over('index').alias('min_index')
)
# Select only the relevant columns and drop duplicates.
df_3 = df_3.select(
    pl.col(['index', 'min_index'])
).unique()
# Finally, left-join the result so rows without a match get null.
df_final = df_1.join(
    other=df_3,
    on='index',
    how='left'
)
print(df_final)
Option 1: df.join_where (experimental)
out = (
    df_1.join(
        df_1
        .join_where(
            df_1.select('index', 'price'),
            pl.col('index_right') > pl.col('index'),
            pl.col('price_right') >= pl.col('limit')
        )
        .group_by('index')
        .agg(
            pl.col('index_right').min().alias('min_index')
        ),
        on='index',
        how='left'
    )
)
Output:
shape: (5, 5)
┌───────┬───────┬───────┬───────┬───────────┐
│ name ┆ index ┆ limit ┆ price ┆ min_index │
│ --- ┆ --- ┆ --- ┆ --- ┆ --- │
│ str ┆ i64 ┆ i64 ┆ i64 ┆ i64 │
╞═══════╪═══════╪═══════╪═══════╪═══════════╡
│ Alpha ┆ 0 ┆ 12 ┆ 10 ┆ 3 │
│ Alpha ┆ 3 ┆ 18 ┆ 15 ┆ null │
│ Alpha ┆ 4 ┆ 11 ┆ 12 ┆ 9 │
│ Alpha ┆ 7 ┆ 5 ┆ 8 ┆ 9 │
│ Alpha ┆ 9 ┆ 9 ┆ 11 ┆ null │
└───────┴───────┴───────┴───────┴───────────┘
Explanation / Intermediates
Use df.join_where, and for other use df.select (note that you don't need 'limit'), adding the filter predicates.
# df_1.join_where(...)
shape: (4, 6)
┌───────┬───────┬───────┬───────┬─────────────┬─────────────┐
│ name ┆ index ┆ limit ┆ price ┆ index_right ┆ price_right │
│ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │
│ str ┆ i64 ┆ i64 ┆ i64 ┆ i64 ┆ i64 │
╞═══════╪═══════╪═══════╪═══════╪═════════════╪═════════════╡
│ Alpha ┆ 0 ┆ 12 ┆ 10 ┆ 3 ┆ 15 │
│ Alpha ┆ 0 ┆ 12 ┆ 10 ┆ 4 ┆ 12 │
│ Alpha ┆ 4 ┆ 11 ┆ 12 ┆ 9 ┆ 11 │
│ Alpha ┆ 7 ┆ 5 ┆ 8 ┆ 9 ┆ 11 │
└───────┴───────┴───────┴───────┴─────────────┴─────────────┘
Use df.group_by to retrieve pl.Expr.min per 'index'.
# df_1.join_where(...).group_by('index').agg(...)
shape: (3, 2)
┌───────┬───────────┐
│ index ┆ min_index │
│ --- ┆ --- │
│ i64 ┆ i64 │
╞═══════╪═══════════╡
│ 0 ┆ 3 │
│ 7 ┆ 9 │
│ 4 ┆ 9 │
└───────┴───────────┘
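Note that the grouped result has only three rows: indices 3 and 9 never matched any later row, so they are absent. The following plain-Python analogy (a dict standing in for the grouped frame, not Polars code) sketches how a left join from the full set of indices brings them back with a missing value.

```python
# Stand-in for the grouped result: index -> min_index.
grouped = {0: 3, 7: 9, 4: 9}

# Every index present in the original frame.
all_indices = [0, 3, 4, 7, 9]

# A left join keeps every left-hand key; keys without a match get None (null).
joined = [(i, grouped.get(i)) for i in all_indices]

print(joined)
# [(0, 3), (3, None), (4, 9), (7, 9), (9, None)]
```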
Join the result back to df_1 with a left join.
Option 2: df.join with "cross" + df.filter
(Adding this option, since df.join_where is experimental. This will be more expensive though.)
out2 = (
    df_1.join(
        df_1
        .join(df_1.select('index', 'price'), how='cross')
        .filter(
            pl.col('index_right') > pl.col('index'),
            pl.col('price_right') >= pl.col('limit')
        )
        .group_by('index')
        .agg(
            pl.col('index_right').min().alias('min_index')
        ),
        on='index',
        how='left'
    )
)
out2.equals(out)
# True
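To get a feel for why the cross join is the more expensive route, here is a rough plain-Python sketch (an analogy, not Polars code) that counts the row pairs involved for the sample data. A cross join followed by a filter first builds every pairing of rows and only then discards most of them.

```python
from itertools import product

# Sample data copied from the DataFrame above.
index = [0, 3, 4, 7, 9]
limit = [12, 18, 11, 5, 9]
price = [10, 15, 12, 8, 11]

# Each row as an (index, limit, price) tuple.
rows = list(zip(index, limit, price))

# A cross join pairs every row with every row: n * n combinations.
all_pairs = list(product(rows, rows))

# Only the pairs passing both predicates survive the filter.
kept_pairs = [
    (left, right)
    for left, right in all_pairs
    if right[0] > left[0] and right[2] >= left[1]
]

print(len(all_pairs), len(kept_pairs))
# 25 4
```

With 5 rows the difference is negligible, but the number of intermediate pairs grows quadratically with the row count, while the 4 surviving pairs are the same ones shown in the join_where intermediate above.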