
Filter by multiple items in lists?


Given a DataFrame with a list column, and a separate Python list of items that is not stored in the DataFrame:

import polars as pl

df = pl.DataFrame({
    "sets": [
        [1, 2, 3],
        [4],
        [9, 10],
        [2, 12],
        [6, 6, 1],
        [2, 0, 1],
        [1, 1, 4],
        [2, 7, 2],
    ]
})

items = [1, 2]

Is there an efficient way to filter the table so that it keeps only the rows whose list value contains a) at least one of the items (ANY) and b) all of the items (ALL)?

Expected result for ALL:

shape: (2, 1)
┌───────────┐
│ sets      │
│ ---       │
│ list[i64] │
╞═══════════╡
│ [1, 2, 3] │
│ [2, 0, 1] │
└───────────┘

Expected result for ANY:

shape: (6, 1)
┌───────────┐
│ sets      │
│ ---       │
│ list[i64] │
╞═══════════╡
│ [1, 2, 3] │
│ [2, 12]   │
│ [6, 6, 1] │
│ [2, 0, 1] │
│ [1, 1, 4] │
│ [2, 7, 2] │
└───────────┘

Solution

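  • Build one boolean test per item and pass all of them to pl.all_horizontal (ALL) or pl.any_horizontal (ANY), which combine the tests row-wise with AND or OR respectively.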

    ALL

    items = [1, 2]
    
    df.filter(
       pl.all_horizontal(
          pl.lit(item).is_in(pl.col("sets")) for item in items
       )
    )
    
    shape: (2, 1)
    ┌───────────┐
    │ sets      │
    │ ---       │
    │ list[i64] │
    ╞═══════════╡
    │ [1, 2, 3] │
    │ [2, 0, 1] │
    └───────────┘
    
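    • Note: .filter also accepts several predicates as positional arguments and ANDs them together, so df.filter(*[pl.lit(item).is_in(pl.col("sets")) for item in items]) can replace pl.all_horizontal for the ALL case.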

    ANY

    df.filter(
       pl.any_horizontal(
          pl.lit(item).is_in(pl.col("sets")) for item in items
       )
    )
    
    shape: (6, 1)
    ┌───────────┐
    │ sets      │
    │ ---       │
    │ list[i64] │
    ╞═══════════╡
    │ [1, 2, 3] │
    │ [2, 12]   │
    │ [6, 6, 1] │
    │ [2, 0, 1] │
    │ [1, 1, 4] │
    │ [2, 7, 2] │
    └───────────┘