
Lookup column name in column


Say I have

import polars as pl

df = pl.DataFrame({
    'a': [1, 2, 1],
    'b': [2, 1, 2],
    'c': [3, 3, 2],
    'column': ['a', 'c', 'b'],
})
shape: (3, 4)
┌─────┬─────┬─────┬────────┐
│ a   ┆ b   ┆ c   ┆ column │
│ --- ┆ --- ┆ --- ┆ ---    │
│ i64 ┆ i64 ┆ i64 ┆ str    │
╞═════╪═════╪═════╪════════╡
│ 1   ┆ 2   ┆ 3   ┆ a      │
│ 2   ┆ 1   ┆ 3   ┆ c      │
│ 1   ┆ 2   ┆ 2   ┆ b      │
└─────┴─────┴─────┴────────┘

I want to add a column that, for each row, takes the value from the column whose name is stored in that row's 'column' field.

Expected output:

shape: (3, 5)
┌─────┬─────┬─────┬────────┬────────┐
│ a   ┆ b   ┆ c   ┆ column ┆ lookup │
│ --- ┆ --- ┆ --- ┆ ---    ┆ ---    │
│ i64 ┆ i64 ┆ i64 ┆ str    ┆ i64    │
╞═════╪═════╪═════╪════════╪════════╡
│ 1   ┆ 2   ┆ 3   ┆ a      ┆ 1      │
│ 2   ┆ 1   ┆ 3   ┆ c      ┆ 3      │
│ 1   ┆ 2   ┆ 2   ┆ b      ┆ 2      │
└─────┴─────┴─────┴────────┴────────┘

Solution

  • This works even when the same column name appears in more than one row (here 'a' appears twice):

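    import polars as pl

    df = pl.DataFrame({
        'a': [1, 2, 1, 3],
        'b': [2, 1, 2, 4],
        'c': [3, 3, 2, 5],
        'column': ['a', 'c', 'b', 'a'],
    })
    # Tag each row with its position, unpivot a/b/c into long format,
    # keep only the value whose variable name matches 'column',
    # then join that value back onto the original rows by position.
    df.with_row_count('i').join(
        df.with_row_count('i')
            .melt(['i', 'column'], value_name='lookup')
            .filter(pl.col('column') == pl.col('variable'))
            .select('i', 'lookup'),
        on='i', how='left').sort('i').drop('i')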
    
    shape: (4, 5)
    ┌─────┬─────┬─────┬────────┬────────┐
    │ a   ┆ b   ┆ c   ┆ column ┆ lookup │
    │ --- ┆ --- ┆ --- ┆ ---    ┆ ---    │
    │ i64 ┆ i64 ┆ i64 ┆ str    ┆ i64    │
    ╞═════╪═════╪═════╪════════╪════════╡
    │ 1   ┆ 2   ┆ 3   ┆ a      ┆ 1      │
    │ 2   ┆ 1   ┆ 3   ┆ c      ┆ 3      │
    │ 1   ┆ 2   ┆ 2   ┆ b      ┆ 2      │
    │ 3   ┆ 4   ┆ 5   ┆ a      ┆ 3      │
    └─────┴─────┴─────┴────────┴────────┘
    

    Here's another way that doesn't use melt

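    # Give each distinct name in 'column' an index (here a=0, b=1, c=2),
    # then use that index to pick the element from the a/b/c list.
    df.join(
        df.select(pl.col('column').unique().sort()).with_row_count('i'),
        on='column'
    ).with_columns(
        lookup=pl.concat_list('a', 'b', 'c').list.get(pl.col('i'))
    ).drop('i')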
    

    It works by combining the 'a', 'b', 'c' columns into a list column and then using .list.get to extract the element at a per-row index, which is possible because get accepts an expression, not just a constant. For that to work, we first have to build a lookup table mapping each name to its position in the list. It breaks badly if the names in column don't exactly match the value columns: for example, if 'b' never appears in column, 'c' is assigned index 1 and the lookup silently reads from 'b' instead. Assuming your data is scrubbed in advance, I think this would be faster than the melt approach.