Tags: python, dataframe, list, aggregation, python-polars

Aggregate a list[str] column by intersecting its elements per group in Polars


I'm trying to aggregate rows of a dataframe that has a list[str] column. For each id, I need the intersection of all the lists in that group. I may be overthinking it, but I can't come up with a solution. Any help, please?

import polars as pl

df = pl.DataFrame(
    {
        "id": [1, 1, 2, 2, 3, 3],
        "values": [["A", "B"], ["B", "C"], ["A", "B"], ["B", "C"], ["A", "B"], ["B", "C"]],
    }
)

Expected output

shape: (3, 2)
┌─────┬───────────┐
│ id  ┆ values    │
│ --- ┆ ---       │
│ i64 ┆ list[str] │
╞═════╪═══════════╡
│ 1   ┆ ["B"]     │
│ 2   ┆ ["B"]     │
│ 3   ┆ ["B"]     │
└─────┴───────────┘

I've tried a few things without success. First attempt:

df.group_by("id").agg(
    pl.reduce(function=lambda acc, x: acc.list.set_intersection(x), 
              exprs=pl.col("values"))
)

# shape: (3, 2)
# ┌─────┬──────────────────────────┐
# │ id  ┆ values                   │
# │ --- ┆ ---                      │
# │ i64 ┆ list[list[str]]          │
# ╞═════╪══════════════════════════╡
# │ 1   ┆ [["A", "B"], ["B", "C"]] │
# │ 3   ┆ [["A", "B"], ["B", "C"]] │
# │ 2   ┆ [["A", "B"], ["B", "C"]] │
# └─────┴──────────────────────────┘

Second attempt, exploding the lists first:

df.group_by("id").agg(
    pl.reduce(function=lambda acc, x: acc.list.set_intersection(x), 
              exprs=pl.col("values").explode())
)

# shape: (3, 2)
# ┌─────┬──────────────────────┐
# │ id  ┆ values               │
# │ --- ┆ ---                  │
# │ i64 ┆ list[str]            │
# ╞═════╪══════════════════════╡
# │ 3   ┆ ["A", "B", "B", "C"] │
# │ 1   ┆ ["A", "B", "B", "C"] │
# │ 2   ┆ ["A", "B", "B", "C"] │
# └─────┴──────────────────────┘

Solution

  • I'm not sure this is as simple as it first seems.

    You could get rid of the lists and use "regular" Polars functionality.

    One way to check whether a value is contained in every row of an id group is to count the number of unique (distinct) row numbers per (id, values) group and compare that count with the number of rows in the id group.

    (df.with_columns(group_len = pl.len().over("id"))
       .with_row_index()
       .explode("values")
       .with_columns(n_unique = pl.col.index.n_unique().over("id", "values"))
    )
    
    shape: (12, 5)
    ┌────────┬─────┬────────┬───────────┬──────────┐
    │ index  ┆ id  ┆ values ┆ group_len ┆ n_unique │
    │ ---    ┆ --- ┆ ---    ┆ ---       ┆ ---      │
    │ u32    ┆ i64 ┆ str    ┆ u32       ┆ u32      │
    ╞════════╪═════╪════════╪═══════════╪══════════╡
    │ 0      ┆ 1   ┆ A      ┆ 2         ┆ 1        │
    │ 0      ┆ 1   ┆ B      ┆ 2         ┆ 2        │ # index = [0, 1]
    │ 1      ┆ 1   ┆ B      ┆ 2         ┆ 2        │
    │ 1      ┆ 1   ┆ C      ┆ 2         ┆ 1        │
    │ 2      ┆ 2   ┆ A      ┆ 2         ┆ 1        │
    │ 2      ┆ 2   ┆ B      ┆ 2         ┆ 2        │ # index = [2, 3]
    │ 3      ┆ 2   ┆ B      ┆ 2         ┆ 2        │
    │ 3      ┆ 2   ┆ C      ┆ 2         ┆ 1        │
    │ 4      ┆ 3   ┆ A      ┆ 2         ┆ 1        │
    │ 4      ┆ 3   ┆ B      ┆ 2         ┆ 2        │ # index = [4, 5]
    │ 5      ┆ 3   ┆ B      ┆ 2         ┆ 2        │
    │ 5      ┆ 3   ┆ C      ┆ 2         ┆ 1        │
    └────────┴─────┴────────┴───────────┴──────────┘
    

    You can then keep only the rows where n_unique equals group_len, and rebuild the lists with .group_by():

    (df.with_columns(pl.len().over("id").alias("group_len"))
       .with_row_index()
       .explode("values")
       .filter(
          pl.col.index.n_unique().over("id", "values")
          == pl.col.group_len
       )
       .group_by("id", maintain_order=True)
       .agg(pl.col.values.unique())
    )
    
    shape: (3, 2)
    ┌─────┬───────────┐
    │ id  ┆ values    │
    │ --- ┆ ---       │
    │ i64 ┆ list[str] │
    ╞═════╪═══════════╡
    │ 1   ┆ ["B"]     │
    │ 2   ┆ ["B"]     │
    │ 3   ┆ ["B"]     │
    └─────┴───────────┘