I am working with a Polars DataFrame and need to compute, for each row, values that depend on the other rows. Currently I am using the map_elements method, but it is not efficient.
In the example below I add two new columns: sum_lower, the sum of all values strictly smaller than the current row's value, and max_other, the maximum over all values different from it. Here is my current implementation:
import polars as pl

COL_VALUE = "value"

def fun_sum_lower(current_row, df):
    # Sum of all values strictly smaller than the current row's value.
    tmp_df = df.filter(pl.col(COL_VALUE) < current_row[COL_VALUE])
    sum_lower = tmp_df.select(pl.sum(COL_VALUE)).item()
    return sum_lower

def fun_max_other(current_row, df):
    # Maximum over all values different from the current row's value.
    tmp_df = df.filter(pl.col(COL_VALUE) != current_row[COL_VALUE])
    max_other = tmp_df.select(pl.col(COL_VALUE)).max().item()
    return max_other

if __name__ == '__main__':
    df = pl.DataFrame({COL_VALUE: [3, 7, 1, 9, 4]})
    df = df.with_columns(
        pl.struct([COL_VALUE])
        .map_elements(lambda row: fun_sum_lower(row, df), return_dtype=pl.Int64)
        .alias("sum_lower")
    )
    df = df.with_columns(
        pl.struct([COL_VALUE])
        .map_elements(lambda row: fun_max_other(row, df), return_dtype=pl.Int64)
        .alias("max_other")
    )
    print(df)
The output of the above code is:
shape: (5, 3)
┌───────┬───────────┬───────────┐
│ value ┆ sum_lower ┆ max_other │
│ --- ┆ --- ┆ --- │
│ i64 ┆ i64 ┆ i64 │
╞═══════╪═══════════╪═══════════╡
│ 3 ┆ 1 ┆ 9 │
│ 7 ┆ 8 ┆ 9 │
│ 1 ┆ 0 ┆ 9 │
│ 9 ┆ 15 ┆ 7 │
│ 4 ┆ 4 ┆ 9 │
└───────┴───────────┴───────────┘
While this code works, it is not efficient due to the use of lambdas and row-wise operations.
Is there a more efficient way to achieve this in Polars, without using lambdas, iterating over rows, or running Python code?
I also tried the Polars methods cum_sum, group_by_dynamic, and rolling, but I don't think they can be used for this task.
For your specific use case you don't really need a join; you can calculate the values with a handful of expressions over the sorted column:

- pl.Expr.shift() to exclude the current row.
- pl.Expr.cum_sum() to calculate the sum of all elements up to the current row.
- pl.Expr.max() to calculate the max.
- pl.Expr.top_k() to get the 2 largest elements, so we can take pl.Expr.min() of those as the second largest.
(
    df
    .sort("value")
    .with_columns(
        sum_lower = pl.col.value.shift(1).cum_sum().fill_null(0),
        max_other =
            pl.when(pl.col.value.max() != pl.col.value)
            .then(pl.col.value.max())
            .otherwise(pl.col.value.top_k(2).min())
    )
)
shape: (5, 3)
┌───────┬───────────┬───────────┐
│ value ┆ sum_lower ┆ max_other │
│ --- ┆ --- ┆ --- │
│ i64 ┆ i64 ┆ i64 │
╞═══════╪═══════════╪═══════════╡
│ 1 ┆ 0 ┆ 9 │
│ 3 ┆ 1 ┆ 9 │
│ 4 ┆ 4 ┆ 9 │
│ 7 ┆ 8 ┆ 9 │
│ 9 ┆ 15 ┆ 7 │
└───────┴───────────┴───────────┘
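Why this works: after sorting, every strictly smaller value sits above the current row, so shifting by one and taking a running sum adds up exactly those values (assuming the values are unique; see the duplicate-values note at the end). A quick way to see the mechanics:

import polars as pl

df = pl.DataFrame({"value": [3, 7, 1, 9, 4]})

# After sorting, shift(1) exposes the previous row and cum_sum turns it
# into a running total of everything above the current row:
# shifted   -> null, 1, 3, 4, 7
# sum_lower ->    0, 1, 4, 8, 15
print(
    df.sort("value").with_columns(
        shifted=pl.col.value.shift(1),
        sum_lower=pl.col.value.shift(1).cum_sum().fill_null(0),
    )
)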
You can also use pl.DataFrame.with_row_index() to keep the current order, so you can revert to it at the end with pl.DataFrame.sort().
(
    df.with_row_index()
    .sort("value")
    .with_columns(
        sum_lower = pl.col.value.shift(1).cum_sum().fill_null(0),
        max_other =
            pl.when(pl.col.value.max() != pl.col.value)
            .then(pl.col.value.max())
            .otherwise(pl.col.value.top_k(2).min())
    )
    .sort("index")
    .drop("index")
)
Another possible solution is to use the DuckDB integration with Polars, taking advantage of DuckDB's excellent window framing options: exclude current row drops the current row from the frame, and max(arg, n) can calculate the 2 largest elements (see the sketch after the output below).

import duckdb

duckdb.sql("""
    select
        d.value,
        coalesce(sum(d.value) over(
            order by d.value
            rows unbounded preceding
            exclude current row
        ), 0) as sum_lower,
        max(d.value) over(
            rows between unbounded preceding and unbounded following
            exclude current row
        ) as max_other
    from df as d
""").pl()
shape: (5, 3)
┌───────┬───────────────┬───────────┐
│ value ┆ sum_lower ┆ max_other │
│ --- ┆ --- ┆ --- │
│ i64 ┆ decimal[38,0] ┆ i64 │
╞═══════╪═══════════════╪═══════════╡
│ 1 ┆ 0 ┆ 9 │
│ 3 ┆ 1 ┆ 9 │
│ 4 ┆ 4 ┆ 9 │
│ 7 ┆ 8 ┆ 9 │
│ 9 ┆ 15 ┆ 7 │
└───────┴───────────────┴───────────┘
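Note that sum() comes back as decimal[38,0] in the Polars conversion; casting (e.g. coalesce(...)::bigint) restores i64 if you need it. As for max(arg, n): it returns a list of the n largest values, which gives another route to the second-largest element without window framing. A minimal sketch (top2 is a name made up here, and like the pure-Polars version it assumes the maximum is unique):

import duckdb
import polars as pl

df = pl.DataFrame({"value": [3, 7, 1, 9, 4]})

# max(value, 2) aggregates the two largest values into a list;
# DuckDB lists are 1-indexed, so top2[1] is the max, top2[2] the runner-up.
duckdb.sql("""
    select
        d.value,
        case when d.value = t.top2[1] then t.top2[2] else t.top2[1] end as max_other
    from df as d, (select max(value, 2) as top2 from df) as t
""").pl()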
Or using a lateral join:
import duckdb

duckdb.sql("""
    select
        d.value,
        coalesce(s.value, 0) as sum_lower,
        m.value as max_other
    from df as d,
        lateral (select sum(t.value) as value from df as t where t.value < d.value) as s,
        lateral (select max(t.value) as value from df as t where t.value != d.value) as m
""").pl()
shape: (5, 3)
┌───────┬───────────┬───────────┐
│ value ┆ sum_lower ┆ max_other │
│ --- ┆ --- ┆ --- │
│ i64 ┆ i64 ┆ i64 │
╞═══════╪═══════════╪═══════════╡
│ 3 ┆ 1 ┆ 9 │
│ 7 ┆ 8 ┆ 9 │
│ 1 ┆ 0 ┆ 9 │
│ 9 ┆ 15 ┆ 7 │
│ 4 ┆ 4 ┆ 9 │
└───────┴───────────┴───────────┘
Duplicate values

The pure-Polars solution above works well if there are no duplicate values, but if there are, you can work around that too. Here are two examples, depending on whether you want to keep the original order or not. Note that sum_lower must weight each distinct value by its count, so duplicates are summed as many times as they occur:
# not keeping original order
(
    df
    .select(pl.col.value.value_counts()).unnest("value")
    .sort("value")
    .with_columns(
        # weight each distinct value by its count so duplicates contribute fully
        sum_lower = (pl.col.value * pl.col.count).shift(1).cum_sum().fill_null(0),
        max_other =
            pl.when(pl.col.value.max() != pl.col.value)
            .then(pl.col.value.max())
            .otherwise(pl.col.value.top_k(2).min()),
        value = pl.col.value.repeat_by("count")
    ).drop("count").explode("value")
)
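If the value_counts step is hard to picture, this is the intermediate frame it produces (a sketch; the input with a duplicated 3 is made up for illustration):

import polars as pl

df = pl.DataFrame({"value": [3, 7, 3, 9, 4]})

# value_counts collapses duplicates into (value, count) pairs;
# repeat_by("count") + explode later restore one row per occurrence.
print(df.select(pl.col.value.value_counts()).unnest("value").sort("value"))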
# keeping original order
(
    df.with_row_index()
    .group_by("value").agg("index")
    .sort("value")
    .with_columns(
        # "index" holds the original row numbers of each duplicate,
        # so its list length is that value's count
        sum_lower = (pl.col.value * pl.col.index.list.len())
            .shift(1).cum_sum().fill_null(0),
        max_other =
            pl.when(pl.col.value.max() != pl.col.value)
            .then(pl.col.value.max())
            .otherwise(pl.col.value.top_k(2).min())
    )
    .explode("index")
    .sort("index")
    .drop("index")
)
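As a sanity check, you can compare the order-preserving variant against the question's row-wise reference on data containing duplicates (a sketch; the input frame is made up):

import polars as pl

df = pl.DataFrame({"value": [3, 7, 3, 9, 4]})

# Row-wise reference, same logic as the question's map_elements version.
reference = df.with_columns(
    sum_lower=pl.struct("value").map_elements(
        lambda row: df.filter(pl.col("value") < row["value"])["value"].sum(),
        return_dtype=pl.Int64,
    ),
    max_other=pl.struct("value").map_elements(
        lambda row: df.filter(pl.col("value") != row["value"])["value"].max(),
        return_dtype=pl.Int64,
    ),
)

# Vectorised version from above.
fast = (
    df.with_row_index()
    .group_by("value").agg("index")
    .sort("value")
    .with_columns(
        sum_lower=(pl.col.value * pl.col.index.list.len())
            .shift(1).cum_sum().fill_null(0),
        max_other=pl.when(pl.col.value.max() != pl.col.value)
            .then(pl.col.value.max())
            .otherwise(pl.col.value.top_k(2).min()),
    )
    .explode("index")
    .sort("index")
    .drop("index")
)

assert fast.equals(reference)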