
Filter and aggregate based on other dataframe


Say I have

import polars as pl

df1 = pl.DataFrame({'start': [1., 2., 4.], 'end': [2., 4., 6.]})
df2 = pl.DataFrame({'idx': [1., 1.7, 2.3, 2.5, 3., 4.], 'values': [3, 1, 4, 2, 3, 5]})

They look like this:

In [8]: df1
Out[8]:
shape: (3, 2)
┌───────┬─────┐
│ start ┆ end │
│ ---   ┆ --- │
│ f64   ┆ f64 │
╞═══════╪═════╡
│ 1.0   ┆ 2.0 │
│ 2.0   ┆ 4.0 │
│ 4.0   ┆ 6.0 │
└───────┴─────┘

In [9]: df2
Out[9]:
shape: (6, 2)
┌─────┬────────┐
│ idx ┆ values │
│ --- ┆ ---    │
│ f64 ┆ i64    │
╞═════╪════════╡
│ 1.0 ┆ 3      │
│ 1.7 ┆ 1      │
│ 2.3 ┆ 4      │
│ 2.5 ┆ 2      │
│ 3.0 ┆ 3      │
│ 4.0 ┆ 5      │
└─────┴────────┘

I would like to end up with something like this:

In [6]: expected = pl.DataFrame({
   ...:     'start': [1., 2., 4.],
   ...:     'end': [2., 4., 6.],
   ...:     'sum_values': [4, 9, 5]
   ...: })

In [7]: expected
Out[7]:
shape: (3, 3)
┌───────┬─────┬────────────┐
│ start ┆ end ┆ sum_values │
│ ---   ┆ --- ┆ ---        │
│ f64   ┆ f64 ┆ i64        │
╞═══════╪═════╪════════════╡
│ 1.0   ┆ 2.0 ┆ 4          │
│ 2.0   ┆ 4.0 ┆ 9          │
│ 4.0   ┆ 6.0 ┆ 5          │
└───────┴─────┴────────────┘
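To make the expected sums easy to check by hand, here is a small plain-Python reference (no Polars) that applies the same half-open rule start <= idx < end that closed="left" gives below. It's only a sketch for checking the arithmetic on this toy data; the list names are my own.

```python
# Intervals from df1 and points from df2, as plain lists
starts = [1.0, 2.0, 4.0]
ends = [2.0, 4.0, 6.0]
idx = [1.0, 1.7, 2.3, 2.5, 3.0, 4.0]
values = [3, 1, 4, 2, 3, 5]

# For each half-open interval [start, end), sum the values whose idx falls inside it
sum_values = [
    sum(v for i, v in zip(idx, values) if s <= i < e)
    for s, e in zip(starts, ends)
]
print(sum_values)  # [4, 9, 5]
```

So [1, 2) picks up 3 + 1, [2, 4) picks up 4 + 2 + 3, and [4, 6) picks up only the 5 at idx 4.0.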

Here's an inefficient approach I came up with, using map_rows:

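(
    df1.with_columns(
        # map_rows calls the lambda once per row of df1 and returns a
        # DataFrame with a single column named "map"
        df1.map_rows(
            lambda row: df2.filter(
                pl.col("idx").is_between(row[0], row[1], closed="left")  # row = (start, end)
            )["values"].sum()
        )["map"].alias("sum_values")
    )
)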

It gives the correct output, but because map_rows calls a Python lambda once per row of df1 (filtering all of df2 each time), it's not as performant as it could be.

Is there a way to write this using the native Polars expressions API?


Solution

  • Update: "non-equi" joins have since been added to Polars (DataFrame.join_where):

    (df1
      .join_where(df2,
         pl.col.idx >= pl.col.start,  # idx on or after the interval start...
         pl.col.idx < pl.col.end      # ...and strictly before its end
      )
      .group_by("start", "end")
      .agg(pl.col.values.sum())
    )
    

    Original answer

    I'm not sure if there is another way apart from a cross join:

    (df1.join(df2, how="cross")  # every (interval, point) pair: len(df1) * len(df2) rows
        .filter(pl.col.idx.is_between("start", "end", closed="left"))  # keep start <= idx < end
        .group_by("start", "end")
        .agg(pl.col.values.sum())
    )
    
    shape: (3, 3)
    ┌───────┬─────┬────────┐
    │ start ┆ end ┆ values │
    │ ---   ┆ --- ┆ ---    │
    │ f64   ┆ f64 ┆ i64    │
    ╞═══════╪═════╪════════╡
    │ 1.0   ┆ 2.0 ┆ 4      │
    │ 4.0   ┆ 6.0 ┆ 5      │
    │ 2.0   ┆ 4.0 ┆ 9      │
    └───────┴─────┴────────┘