
How to get an element index in a list column, if element is specified in a different column


I have a dataframe where one column, a, is a list, and another column, b, contains a value that is in a. I need to create a column c that contains the index of the element b within the list a.

import polars as pl

df = pl.DataFrame({'a': [[1, 2, 3], [4, 5, 2], [6, 2, 7]], 'b': [3, 4, 2]})
print(df)

shape: (3, 2)
┌───────────┬─────┐
│ a         ┆ b   │
│ ---       ┆ --- │
│ list[i64] ┆ i64 │
╞═══════════╪═════╡
│ [1, 2, 3] ┆ 3   │
│ [4, 5, 2] ┆ 4   │
│ [6, 2, 7] ┆ 2   │
└───────────┴─────┘

The resulting dataframe should look like the following:

shape: (3, 3)
┌───────────┬─────┬────────────┐
│ a         ┆ b   ┆ a.index(b) │
│ ---       ┆ --- ┆ ---        │
│ list[i64] ┆ i64 ┆ i64        │
╞═══════════╪═════╪════════════╡
│ [1, 2, 3] ┆ 3   ┆ 2          │
│ [4, 5, 2] ┆ 4   ┆ 0          │
│ [6, 2, 7] ┆ 2   ┆ 1          │
└───────────┴─────┴────────────┘

All elements of a are unique within each row, and b is guaranteed to be in a.
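
In plain Python terms, what I am after is the row-wise equivalent of `list.index`. A minimal sketch of the expected values, not Polars code (the names `expected`, `row`, and `value` are only for illustration):

```python
# Plain-Python reference for the desired column, using the data from the question.
a = [[1, 2, 3], [4, 5, 2], [6, 2, 7]]
b = [3, 4, 2]

# list.index returns the position of the first occurrence and raises
# ValueError if the value is missing; here b is guaranteed to be in a.
expected = [row.index(value) for row, value in zip(a, b)]
print(expected)  # [2, 0, 1]
```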


Solution

  • We want to make two helper columns: one is just a copy of a (acopy), and the other holds the index positions of a. We then explode both of them together and filter for rows where b == acopy. Lastly, we clean up by dropping acopy.

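    (
        df
            .with_columns(
                # helper 1: a copy of the list column to explode
                pl.col('a').alias('acopy'),
                # helper 2: the positions 0..len(a)-1 for each row
                pl.int_ranges(pl.col('a').list.len()).alias('a.index(b)'))
            # explode both helpers together so each value stays next to its position
            .explode('acopy','a.index(b)')
            # keep only the position whose value equals b
            .filter(pl.col('acopy')==pl.col('b'))
            .drop('acopy')
    )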
    
    shape: (3, 3)
    ┌───────────┬─────┬────────────┐
    │ a         ┆ b   ┆ a.index(b) │
    │ ---       ┆ --- ┆ ---        │
    │ list[i64] ┆ i64 ┆ i64        │
    ╞═══════════╪═════╪════════════╡
    │ [1, 2, 3] ┆ 3   ┆ 2          │
    │ [4, 5, 2] ┆ 4   ┆ 0          │
    │ [6, 2, 7] ┆ 2   ┆ 1          │
    └───────────┴─────┴────────────┘
    

    We can make this robust to cases where your guarantees do not hold: if a contains duplicates, we take the first matching position, and if b is not in a, we still return the row, with a null as its index.

    So, say we started with:

    df = pl.DataFrame({'a': [[1, 2, 4], [4, 4, 2], [6, 2, 7]], 'b': [3, 4, 2]})
    

    then we could do

    (
        df
            # row id, so exploded rows can be grouped back together later
            .with_row_index('i')
            .with_columns(
                pl.col('a').alias('acopy'), 
                pl.int_ranges(pl.col('a').list.len()).alias('a.index(b)'))
            .explode('acopy','a.index(b)')
            .with_columns(
                # keep the position only where the value matches b, else null
                (pl.when(pl.col('acopy')==pl.col('b'))
                    .then(pl.col('a.index(b)'))
                    .otherwise(None)).alias('a.index(b)')
                )
            .group_by('i', maintain_order=True)
            .agg(
                pl.exclude('a.index(b)').first(),
                # min skips nulls: first match, or null if there is none
                pl.col('a.index(b)').min())
            .drop('i', 'acopy')
    )
    
    shape: (3, 3)
    ┌───────────┬─────┬────────────┐
    │ a         ┆ b   ┆ a.index(b) │
    │ ---       ┆ --- ┆ ---        │
    │ list[i64] ┆ i64 ┆ i64        │
    ╞═══════════╪═════╪════════════╡
    │ [1, 2, 4] ┆ 3   ┆ null       │
    │ [4, 4, 2] ┆ 4   ┆ 0          │
    │ [6, 2, 7] ┆ 2   ┆ 1          │
    └───────────┴─────┴────────────┘
    

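    Essentially, this trades the filter for a when.then.otherwise plus a group_by. For the group_by to work, we have to create a row index helper column at the beginning, so that the exploded rows can be collapsed back into one row per original row. Taking the min of the surviving positions gives the first match, and a row with no match is left with only nulls, so its result is null.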