Index operation on list column data in polars


I have a column containing lists of strings, and I need to check whether one particular string occurs before another in each list.

import polars as pl

df = pl.LazyFrame(
    {
        'str': ['A', 'B', 'C', 'B', 'A'],
        'group': [1,1,2,1,2]
    }
)

df_groups = df.group_by('group').agg(pl.col('str').alias('str_list'))
print(df_groups.collect())
shape: (2, 2)
┌───────┬─────────────────┐
│ group ┆ str_list        │
│ ---   ┆ ---             │
│ i64   ┆ list[str]       │
╞═══════╪═════════════════╡
│ 1     ┆ ["A", "B", "B"] │
│ 2     ┆ ["C", "A"]      │
└───────┴─────────────────┘

I have written the following example, which works but has to break out of polars via map_elements, and that makes it very slow.

pre = 'A'
succ = 'B'

df_groups_filtered = df_groups.filter(
    pl.col('str_list').map_elements(
        lambda str_list: 
            pre in str_list and succ in str_list and 
            str_list.to_list().index(pre) < str_list.to_list().index(succ)
    )
)

df_groups_filtered.collect()

This provides the desired result:

shape: (1, 2)
┌───────┬─────────────────┐
│ group ┆ str_list        │
│ ---   ┆ ---             │
│ i64   ┆ list[str]       │
╞═══════╪═════════════════╡
│ 1     ┆ ["A", "B", "B"] │
└───────┴─────────────────┘

I know that I can do

df_groups_filtered = df_groups.filter(
    pl.col('str_list').list.contains(pre) & pl.col('str_list').list.contains(succ)
)

to check that both strings are contained, but I couldn't figure out how to check their order in pure polars.

Are there ways to achieve this natively with polars?


Solution

  • I'd like to add another solution, which is the one I actually ended up adopting. Credit goes to @mcrumiller, who posted it on GitHub.

    import polars as pl
    
    def loc_of(value):
        # only execute if the item is contained in the list
        return pl.when(pl.col("list").list.contains(value)).then(
            pl.col("list").list.eval(
                # create array of True/False, then cast to 1's and 0's
                # arg_max() then finds the first occurrence of 1, i.e. the first occurrence of value
                (pl.element() == value).cast(pl.UInt8).arg_max(),
                parallel=True
            ).list.first()
        )
    
    df = pl.DataFrame({
        "id": [1,2,3,4],
        "list": [['A', 'B'], ['B', 'A'], ['C', 'B', 'D'], ['D', 'A', 'C', 'B']]
    })
    
    df.filter(loc_of('A') < loc_of('B'))
    
    shape: (2, 2)
    ┌─────┬───────────────────┐
    │ id  ┆ list              │
    │ --- ┆ ---               │
    │ i64 ┆ list[str]         │
    ╞═════╪═══════════════════╡
    │ 1   ┆ ["A", "B"]        │
    │ 4   ┆ ["D", "A", … "B"] │
    └─────┴───────────────────┘
    

    I really like the simplicity of this approach. Performance-wise it is very similar to the approach of @user18559875.