import polars as pl
df = pl.DataFrame({'a': [[1, 2, 3], [8, 9, 4]], 'b': [[2, 3, 4], [4, 5, 6]]})
So given the dataframe df

a          b
[1, 2, 3]  [2, 3, 4]
[8, 9, 4]  [4, 5, 6]

I would like to get a column c that is the intersection of a and b:

a          b          c
[1, 2, 3]  [2, 3, 4]  [2, 3]
[8, 9, 4]  [4, 5, 6]  [4]

I know I can use apply with Python's set intersection, but I want to do it using polars expressions.
Use the set operations for lists:

df.select(
    intersection=pl.col('a').list.set_intersection('b'),
    difference=pl.col('a').list.set_difference('b'),
    union=pl.col('a').list.set_union('b')
)
Use the set operations for lists (with the old names):

df.select(
    intersection=pl.col('a').list.intersection('b'),
    difference=pl.col('a').list.difference('b'),
    union=pl.col('a').list.union('b')
)
We can accomplish the intersection using the arr.eval expression, which allows us to treat a list as a Series/column, so that we can use the same contexts and expressions that we use with columns and Series.
First, let's extend your example so that we can show what happens when the intersection is empty.
df = pl.DataFrame(
    {
        "a": [[1, 2, 3], [8, 9, 4], [0, 1, 2]],
        "b": [[2, 3, 4], [4, 5, 6], [10, 11, 12]],
    }
)
df
shape: (3, 2)
┌───────────┬──────────────┐
│ a ┆ b │
│ --- ┆ --- │
│ list[i64] ┆ list[i64] │
╞═══════════╪══════════════╡
│ [1, 2, 3] ┆ [2, 3, 4] │
├╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ [8, 9, 4] ┆ [4, 5, 6] │
├╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ [0, 1, 2] ┆ [10, 11, 12] │
└───────────┴──────────────┘
There are two ways to accomplish this. The first is extendable to the intersection of more than two sets (see Other Notes below).
df.with_column(
    pl.col("a")
    .arr.concat('b')
    .arr.eval(pl.element().filter(pl.count().over(pl.element()) == 2))
    .arr.unique()
    .alias('intersection')
)
or
df.with_column(
    pl.col("a")
    .arr.concat('b')
    .arr.eval(pl.element().filter(pl.element().is_duplicated()))
    .arr.unique()
    .alias('intersection')
)
shape: (3, 3)
┌───────────┬──────────────┬──────────────┐
│ a ┆ b ┆ intersection │
│ --- ┆ --- ┆ --- │
│ list[i64] ┆ list[i64] ┆ list[i64] │
╞═══════════╪══════════════╪══════════════╡
│ [1, 2, 3] ┆ [2, 3, 4] ┆ [2, 3] │
├╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ [8, 9, 4] ┆ [4, 5, 6] ┆ [4] │
├╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ [0, 1, 2] ┆ [10, 11, 12] ┆ [] │
└───────────┴──────────────┴──────────────┘
We first concatenate the two lists into a single list. Any element that is in both lists will appear twice.
df.with_column(
    pl.col("a")
    .arr.concat('b')
    .alias('ablist')
)
shape: (3, 3)
┌───────────┬──────────────┬────────────────┐
│ a ┆ b ┆ ablist │
│ --- ┆ --- ┆ --- │
│ list[i64] ┆ list[i64] ┆ list[i64] │
╞═══════════╪══════════════╪════════════════╡
│ [1, 2, 3] ┆ [2, 3, 4] ┆ [1, 2, ... 4] │
├╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ [8, 9, 4] ┆ [4, 5, 6] ┆ [8, 9, ... 6] │
├╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ [0, 1, 2] ┆ [10, 11, 12] ┆ [0, 1, ... 12] │
└───────────┴──────────────┴────────────────┘
Then we can use the arr.eval function, which allows us to treat the concatenated list as if it were a Series/column. In this case, we'll use a filter context to find any element that appears more than once. (The polars.element expression in a list context is used like polars.col is used with a Series.)
df.with_column(
    pl.col("a")
    .arr.concat('b')
    .arr.eval(pl.element().filter(pl.count().over(pl.element()) == 2))
    .alias('filtered')
)
shape: (3, 3)
┌───────────┬──────────────┬───────────────┐
│ a ┆ b ┆ filtered │
│ --- ┆ --- ┆ --- │
│ list[i64] ┆ list[i64] ┆ list[i64] │
╞═══════════╪══════════════╪═══════════════╡
│ [1, 2, 3] ┆ [2, 3, 4] ┆ [2, 3, ... 3] │
├╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ [8, 9, 4] ┆ [4, 5, 6] ┆ [4, 4] │
├╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ [0, 1, 2] ┆ [10, 11, 12] ┆ [] │
└───────────┴──────────────┴───────────────┘
Note: the above step can also be expressed using the is_duplicated expression. (In the Other Notes section, we'll see that is_duplicated will not work when calculating the intersection of more than two sets.)
df.with_column(
    pl.col("a")
    .arr.concat('b')
    .arr.eval(pl.element().filter(pl.element().is_duplicated()))
    .alias('filtered')
)
All that remains is to remove the duplicates from the result, using the arr.unique expression (which yields the output shown at the beginning).

Other Notes

I'm assuming that your lists are really sets, in that each element appears only once per list. If there are duplicates in the original lists, we can apply arr.unique to each list before the concatenation step.
Also, this process can be extended to find the intersection of more than two sets. Simply concatenate all the lists together, and then change the filter step from == 2 to == n (where n is the number of sets). (Note: the is_duplicated expression above will not work with more than two sets, because an element appearing in only two of the n lists would also be flagged as a duplicate.)
The arr.eval method does have a parallel keyword. You can try setting this to True to see whether it yields better performance in your particular situation.
Symmetric difference: change the filter criterion to == 1 (and omit the arr.unique step).
Union: use arr.concat followed by arr.unique.
Set difference: compute the intersection (as above), then concatenate the original list/set and filter for items that appear only once. Alternatively, for small list sizes, you can concatenate "a" to itself and then to "b", and then filter for elements that occur exactly twice (but not three times).