
Filter on `list(Int64)` dtype in polars


Say I have

In [20]: df = pl.DataFrame({'a': [[1,2,3], [1,4,2], [1,3,3]], 'b': [4,2,1]})

In [21]: df
Out[21]:
shape: (3, 2)
┌───────────┬─────┐
│ a         ┆ b   │
│ ---       ┆ --- │
│ list[i64] ┆ i64 │
╞═══════════╪═════╡
│ [1, 2, 3] ┆ 4   │
│ [1, 4, 2] ┆ 2   │
│ [1, 3, 3] ┆ 1   │
└───────────┴─────┘

I'd like to keep only the rows where 'a' equals [1, 2, 3].

I've tried

In [23]: df.filter(pl.col('a')==[1,2,3])

but it raises

ArrowErrorException: NotYetImplemented("Casting from Int64 to LargeList(Field { name: \"item\", data_type: Int64, is_nullable: true, metadata: {} }) not supported")


Solution

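  • You can hash the list column and a list literal, then compare the two hashes. Equal lists produce equal hashes. Two different lists could in principle collide, so a match is very likely, but not guaranteed, to mean equality: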

    df.filter(pl.col('a').hash() == pl.lit([1,2,3]).hash())