I'm trying to aggregate some rows in my DataFrame with a list[str]
column. For each index I need the intersection of all the lists in the group. Not sure if I'm overthinking it, but I can't come up with a solution. Any help, please?
import polars as pl
input_df = pl.DataFrame(
    {
        "idx": [1, 1, 2, 2, 3, 3],
        "values": [["A", "B"], ["B", "C"], ["A", "B"], ["B", "C"], ["A", "B"], ["B", "C"]],
    }
)
output_df = input_df.agg(...)
output_df = input_df.agg(...)
>>> input_df
shape: (6, 2)
┌─────┬────────────┐
│ idx ┆ values │
│ --- ┆ --- │
│ i64 ┆ list[str] │
╞═════╪════════════╡
│ 1 ┆ ["A", "B"] │
│ 1 ┆ ["B", "C"] │
│ 2 ┆ ["A", "B"] │
│ 2 ┆ ["B", "C"] │
│ 3 ┆ ["A", "B"] │
│ 3 ┆ ["B", "C"] │
└─────┴────────────┘
>>> output_df # Expected output
shape: (3, 2)
┌─────┬───────────┐
│ idx ┆ values │
│ --- ┆ --- │
│ i64 ┆ list[str] │
╞═════╪═══════════╡
│ 1 ┆ ["B"] │
│ 2 ┆ ["B"] │
│ 3 ┆ ["B"] │
└─────┴───────────┘
I've tried a couple of things, without success:
>>> input_df.group_by("idx").agg(
pl.reduce(function=lambda acc, x: acc.list.set_intersection(x),
exprs=pl.col("values"))
)
shape: (3, 2)
┌─────┬──────────────────────────┐
│ idx ┆ values │
│ --- ┆ --- │
│ i64 ┆ list[list[str]] │
╞═════╪══════════════════════════╡
│ 1 ┆ [["A", "B"], ["B", "C"]] │
│ 2 ┆ [["A", "B"], ["B", "C"]] │
│ 3 ┆ [["A", "B"], ["B", "C"]] │
└─────┴──────────────────────────┘
Another attempt:
>>> input_df.group_by("idx").agg(
pl.reduce(function=lambda acc, x: acc.list.set_intersection(x),
exprs=pl.col("values").explode())
)
shape: (3, 2)
┌─────┬───────────────────┐
│ idx ┆ values │
│ --- ┆ --- │
│ i64 ┆ list[str] │
╞═════╪═══════════════════╡
│ 3 ┆ ["A", "B", … "C"] │
│ 2 ┆ ["A", "B", … "C"] │
│ 1 ┆ ["A", "B", … "C"] │
└─────┴───────────────────┘
I'm not sure this is as simple as it may seem. Note that pl.reduce folds horizontally across the expressions you pass in exprs, so with a single expression there is nothing to combine, which is why your first attempt just collected the lists unchanged.
You could get rid of the lists and use "regular" Polars functionality instead.
A value belongs to the intersection exactly when it appears in every row of its idx
group. One way to check that is to count the number of unique (distinct) row numbers per idx, values
group and compare it with the group size.
(input_df
    .with_columns(group_len=pl.count().over("idx"))
    .with_row_count()
    .explode("values")
    .with_columns(num_rows=pl.n_unique("row_nr").over("idx", "values"))
)
shape: (12, 5)
┌────────┬─────┬────────┬───────────┬──────────┐
│ row_nr ┆ idx ┆ values ┆ group_len ┆ num_rows │
│ --- ┆ --- ┆ --- ┆ --- ┆ --- │
│ u32 ┆ i64 ┆ str ┆ u32 ┆ u32 │
╞════════╪═════╪════════╪═══════════╪══════════╡
│ 0 ┆ 1 ┆ A ┆ 2 ┆ 1 │
│ 0 ┆ 1 ┆ B ┆ 2 ┆ 2 │ # row_nr = [0, 1]
│ 1 ┆ 1 ┆ B ┆ 2 ┆ 2 │
│ 1 ┆ 1 ┆ C ┆ 2 ┆ 1 │
│ 2 ┆ 2 ┆ A ┆ 2 ┆ 1 │
│ 2 ┆ 2 ┆ B ┆ 2 ┆ 2 │ # row_nr = [2, 3]
│ 3 ┆ 2 ┆ B ┆ 2 ┆ 2 │
│ 3 ┆ 2 ┆ C ┆ 2 ┆ 1 │
│ 4 ┆ 3 ┆ A ┆ 2 ┆ 1 │
│ 4 ┆ 3 ┆ B ┆ 2 ┆ 2 │ # row_nr = [4, 5]
│ 5 ┆ 3 ┆ B ┆ 2 ┆ 2 │
│ 5 ┆ 3 ┆ C ┆ 2 ┆ 1 │
└────────┴─────┴────────┴───────────┴──────────┘
You can filter those rows and build the result lists with .group_by:
(input_df
    .with_columns(group_len=pl.count().over("idx"))
    .with_row_count()
    .explode("values")
    .filter(
        pl.n_unique("row_nr").over("idx", "values") == pl.col("group_len")
    )
    .group_by("idx", maintain_order=True)
    .agg(pl.col("values").unique())
)
shape: (3, 2)
┌─────┬───────────┐
│ idx ┆ values │
│ --- ┆ --- │
│ i64 ┆ list[str] │
╞═════╪═══════════╡
│ 1 ┆ ["B"] │
│ 2 ┆ ["B"] │
│ 3 ┆ ["B"] │
└─────┴───────────┘