python-polars

Polars: intersection of list columns in a DataFrame


import polars as pl

df = pl.DataFrame({'a': [[1, 2, 3], [8, 9, 4]], 'b': [[2, 3, 4], [4, 5, 6]]})

Given the DataFrame df

    a           b
[1, 2, 3]   [2, 3, 4]
[8, 9, 4]   [4, 5, 6]

I would like to get a column c that is the intersection of a and b:

    a           b          c
[1, 2, 3]   [2, 3, 4]    [2, 3]
[8, 9, 4]   [4, 5, 6]     [4]

I know I can use the apply function with Python's set intersection, but I want to do it using Polars expressions.


Solution

  • polars >= 0.18.10

    Use the built-in set operations for lists:

    df.select(
       intersection = pl.col('a').list.set_intersection('b'),
       difference = pl.col('a').list.set_difference('b'),
       union = pl.col('a').list.set_union('b')
    )  
    

  • polars >= 0.18.5, polars < 0.18.10

    Use the same set operations under their old names:

    df.select(
       intersection = pl.col('a').list.intersection('b'),
       difference = pl.col('a').list.difference('b'),
       union = pl.col('a').list.union('b')
    )  
    

  • polars < 0.18.5

    We can accomplish the intersection using the arr.eval expression, which allows us to treat a list as a Series, so that we can use the same contexts and expressions that we use with regular columns.

    First, let's extend your example so that we can show what happens when the intersection is empty.

    df = pl.DataFrame(
        {
            "a": [[1, 2, 3], [8, 9, 4], [0, 1, 2]],
            "b": [[2, 3, 4], [4, 5, 6], [10, 11, 12]],
        }
    )
    df
    
    shape: (3, 2)
    ┌───────────┬──────────────┐
    │ a         ┆ b            │
    │ ---       ┆ ---          │
    │ list[i64] ┆ list[i64]    │
    ╞═══════════╪══════════════╡
    │ [1, 2, 3] ┆ [2, 3, 4]    │
    ├╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
    │ [8, 9, 4] ┆ [4, 5, 6]    │
    ├╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
    │ [0, 1, 2] ┆ [10, 11, 12] │
    └───────────┴──────────────┘
    

    The Algorithm

    There are two ways to accomplish this. The first is extendable to the intersection of more than two sets (see Other Notes below).

    df.with_column(
        pl.col("a")
        .arr.concat('b')
        .arr.eval(pl.element().filter(pl.count().over(pl.element()) == 2))
        .arr.unique()
        .alias('intersection')
    )
    

    or

    df.with_column(
        pl.col("a")
        .arr.concat('b')
        .arr.eval(pl.element().filter(pl.element().is_duplicated()))
        .arr.unique()
        .alias('intersection')
    )
    
    shape: (3, 3)
    ┌───────────┬──────────────┬──────────────┐
    │ a         ┆ b            ┆ intersection │
    │ ---       ┆ ---          ┆ ---          │
    │ list[i64] ┆ list[i64]    ┆ list[i64]    │
    ╞═══════════╪══════════════╪══════════════╡
    │ [1, 2, 3] ┆ [2, 3, 4]    ┆ [2, 3]       │
    ├╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
    │ [8, 9, 4] ┆ [4, 5, 6]    ┆ [4]          │
    ├╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
    │ [0, 1, 2] ┆ [10, 11, 12] ┆ []           │
    └───────────┴──────────────┴──────────────┘
    

    How it works

    We first concatenate the two lists into a single list. Any element that is in both lists will appear twice.

    df.with_column(
        pl.col("a")
        .arr.concat('b')
        .alias('ablist')
    )
    
    shape: (3, 3)
    ┌───────────┬──────────────┬────────────────┐
    │ a         ┆ b            ┆ ablist         │
    │ ---       ┆ ---          ┆ ---            │
    │ list[i64] ┆ list[i64]    ┆ list[i64]      │
    ╞═══════════╪══════════════╪════════════════╡
    │ [1, 2, 3] ┆ [2, 3, 4]    ┆ [1, 2, ... 4]  │
    ├╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
    │ [8, 9, 4] ┆ [4, 5, 6]    ┆ [8, 9, ... 6]  │
    ├╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
    │ [0, 1, 2] ┆ [10, 11, 12] ┆ [0, 1, ... 12] │
    └───────────┴──────────────┴────────────────┘
    

    Then we can use the arr.eval function, which allows us to treat the concatenated list as if it were a Series/column. In this case, we'll use a filter context to find any element that appears more than once. (In a list context, the polars.element expression plays the role that polars.col plays in an ordinary expression context.)

    df.with_column(
        pl.col("a")
        .arr.concat('b')
        .arr.eval(pl.element().filter(pl.count().over(pl.element()) == 2))
        .alias('filtered')
    )
    
    shape: (3, 3)
    ┌───────────┬──────────────┬───────────────┐
    │ a         ┆ b            ┆ filtered      │
    │ ---       ┆ ---          ┆ ---           │
    │ list[i64] ┆ list[i64]    ┆ list[i64]     │
    ╞═══════════╪══════════════╪═══════════════╡
    │ [1, 2, 3] ┆ [2, 3, 4]    ┆ [2, 3, ... 3] │
    ├╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
    │ [8, 9, 4] ┆ [4, 5, 6]    ┆ [4, 4]        │
    ├╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
    │ [0, 1, 2] ┆ [10, 11, 12] ┆ []            │
    └───────────┴──────────────┴───────────────┘
    

    Note: the above step can also be expressed using the is_duplicated expression. (In the Other Notes section, we'll see that using is_duplicated will not work when calculating the intersection of more than two sets.)

    df.with_column(
        pl.col("a")
        .arr.concat('b')
        .arr.eval(pl.element().filter(pl.element().is_duplicated()))
        .alias('filtered')
    )
    

    All that remains is to remove the duplicates from the result using the arr.unique expression, which produces the output shown at the beginning.

    Other Notes

    I'm assuming that your lists are really sets, in that each element appears at most once per list. If the original lists contain duplicates, we can apply arr.unique to each list before the concatenation step.

    Also, this process can be extended to find the intersection of more than two sets. Simply concatenate all the lists together, and then change the filter step from == 2 to == n (where n is the number of sets). (Note: using the is_duplicated expression above will not work with more than two sets.)

    The arr.eval method does have a parallel keyword. You can try setting this to True and see if it yields better performance in your particular situation.

    Other Set Operations

    Symmetric difference: change the filter criterion to == 1 (and omit the arr.unique step).

    Union: use arr.concat followed by arr.unique.

    Set difference: compute the intersection (as above), then concatenate the original list/set and filter for items that appear only once. Alternatively, for small list sizes, you can concatenate “a” to itself and then to “b” and then filter for elements that occur twice (but not three times).