
Aggregate a list[str] column by intersecting the lists in each group with Polars


I'm trying to aggregate rows of a DataFrame that has a list[str] column. For each idx, I need the intersection of all the lists in that group. I may be overthinking it, but I can't come up with a solution. Any help?

import polars as pl

input_df = pl.DataFrame(
    {
        "idx": [1, 1, 2, 2, 3, 3],
        "values": [["A", "B"], ["B", "C"], ["A", "B"], ["B", "C"], ["A", "B"], ["B", "C"]],
    }
)

output_df = input_df.group_by("idx").agg(...)

>>> input_df
shape: (6, 2)
┌─────┬────────────┐
│ idx ┆ values     │
│ --- ┆ ---        │
│ i64 ┆ list[str]  │
╞═════╪════════════╡
│ 1   ┆ ["A", "B"] │
│ 1   ┆ ["B", "C"] │
│ 2   ┆ ["A", "B"] │
│ 2   ┆ ["B", "C"] │
│ 3   ┆ ["A", "B"] │
│ 3   ┆ ["B", "C"] │
└─────┴────────────┘
>>> output_df # Expected output
shape: (3, 2)
┌─────┬───────────┐
│ idx ┆ values    │
│ --- ┆ ---       │
│ i64 ┆ list[str] │
╞═════╪═══════════╡
│ 1   ┆ ["B"]     │
│ 2   ┆ ["B"]     │
│ 3   ┆ ["B"]     │
└─────┴───────────┘
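For reference, the intended result can be stated in plain Python with sets: collect every list of an idx group and fold them with set intersection. This is only a sketch to check a Polars solution against; `rows` just restates `input_df` as tuples, and sorting each result is an arbitrary choice to make the output deterministic.

```python
from functools import reduce

# input_df restated as (idx, values) tuples
rows = [
    (1, ["A", "B"]), (1, ["B", "C"]),
    (2, ["A", "B"]), (2, ["B", "C"]),
    (3, ["A", "B"]), (3, ["B", "C"]),
]

# Group the lists (as sets) by idx.
groups: dict[int, list[set[str]]] = {}
for idx, values in rows:
    groups.setdefault(idx, []).append(set(values))

# Intersect all sets of each group; sort for a stable output.
expected = {idx: sorted(reduce(set.intersection, sets)) for idx, sets in groups.items()}
print(expected)
```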

I've tried a few things without success:

>>> input_df.group_by("idx").agg(
  pl.reduce(function=lambda acc, x: acc.list.set_intersection(x), 
     exprs=pl.col("values"))
)
shape: (3, 2)
┌─────┬──────────────────────────┐
│ idx ┆ values                   │
│ --- ┆ ---                      │
│ i64 ┆ list[list[str]]          │
╞═════╪══════════════════════════╡
│ 1   ┆ [["A", "B"], ["B", "C"]] │
│ 2   ┆ [["A", "B"], ["B", "C"]] │
│ 3   ┆ [["A", "B"], ["B", "C"]] │
└─────┴──────────────────────────┘

Another attempt:

>>> input_df.group_by("idx").agg(
   pl.reduce(function=lambda acc, x: acc.list.set_intersection(x), 
   exprs=pl.col("values").explode())
)
shape: (3, 2)
┌─────┬───────────────────┐
│ idx ┆ values            │
│ --- ┆ ---               │
│ i64 ┆ list[str]         │
╞═════╪═══════════════════╡
│ 3   ┆ ["A", "B", … "C"] │
│ 2   ┆ ["A", "B", … "C"] │
│ 1   ┆ ["A", "B", … "C"] │
└─────┴───────────────────┘

Solution

  • I'm not sure if this is as simple as it may seem.

    You could get rid of the lists (by exploding them) and use "regular" Polars functionality.

    One way to check whether a value appears in every row of an idx group is to count the number of unique (distinct) row numbers per (idx, values) group and compare it with the number of rows in the idx group.

    (input_df.with_columns(group_len=pl.len().over("idx"))  # rows per idx
       .with_row_index("row_nr")  # remember which row each value came from
       .explode("values")
       .with_columns(num_rows=
           pl.n_unique("row_nr").over("idx", "values")  # rows containing this value
       )
    )
    
    shape: (12, 5)
    ┌────────┬─────┬────────┬───────────┬──────────┐
    │ row_nr ┆ idx ┆ values ┆ group_len ┆ num_rows │
    │ ---    ┆ --- ┆ ---    ┆ ---       ┆ ---      │
    │ u32    ┆ i64 ┆ str    ┆ u32       ┆ u32      │
    ╞════════╪═════╪════════╪═══════════╪══════════╡
    │ 0      ┆ 1   ┆ A      ┆ 2         ┆ 1        │
    │ 0      ┆ 1   ┆ B      ┆ 2         ┆ 2        │ # row_nr = [0, 1]
    │ 1      ┆ 1   ┆ B      ┆ 2         ┆ 2        │
    │ 1      ┆ 1   ┆ C      ┆ 2         ┆ 1        │
    │ 2      ┆ 2   ┆ A      ┆ 2         ┆ 1        │
    │ 2      ┆ 2   ┆ B      ┆ 2         ┆ 2        │ # row_nr = [2, 3]
    │ 3      ┆ 2   ┆ B      ┆ 2         ┆ 2        │
    │ 3      ┆ 2   ┆ C      ┆ 2         ┆ 1        │
    │ 4      ┆ 3   ┆ A      ┆ 2         ┆ 1        │
    │ 4      ┆ 3   ┆ B      ┆ 2         ┆ 2        │ # row_nr = [4, 5]
    │ 5      ┆ 3   ┆ B      ┆ 2         ┆ 2        │
    │ 5      ┆ 3   ┆ C      ┆ 2         ┆ 1        │
    └────────┴─────┴────────┴───────────┴──────────┘
    

    You can then keep only the values whose distinct row count equals group_len and rebuild the lists with .group_by:

    (input_df.with_columns(group_len=pl.len().over("idx"))
       .with_row_index("row_nr")
       .explode("values")
       .filter(
           # keep values that occur in every row of their idx group
           pl.n_unique("row_nr").over("idx", "values")
           == pl.col("group_len")
       )
       .group_by("idx", maintain_order=True)
       .agg(pl.col("values").unique())
    )
    
    shape: (3, 2)
    ┌─────┬───────────┐
    │ idx ┆ values    │
    │ --- ┆ ---       │
    │ i64 ┆ list[str] │
    ╞═════╪═══════════╡
    │ 1   ┆ ["B"]     │
    │ 2   ┆ ["B"]     │
    │ 3   ┆ ["B"]     │
    └─────┴───────────┘