python, python-polars

How to add a numeric value from one column to the elements of a List column in Polars?


Suppose I have the following Polars DataFrame:

import polars as pl

df = pl.DataFrame({
    'lst': [[0, 1], [9, 8]],
    'val': [3, 4]
})

And I want to add the number in the val column to every element of the corresponding list in the lst column, so that I get the following result:

┌───────────┬─────┐
│ lst       ┆ val │
│ ---       ┆ --- │
│ list[i64] ┆ i64 │
╞═══════════╪═════╡
│ [3, 4]    ┆ 3   │
│ [13, 12]  ┆ 4   │
└───────────┴─────┘

I know how to add a constant value, e.g.:

new_df = df.with_columns(
    pl.col('lst').list.eval(pl.element() + 2)
)

But when I try:

new_df = df.with_columns(
    pl.col('lst').list.eval(pl.element() + pl.col('val'))
)

I get the following error:

polars.exceptions.ComputeError: named columns are not allowed in `list.eval`; consider using `element` or `col("")`

Is there any elegant way to achieve my goal (without map_elements)?

Thanks in advance.


Solution

  • If lst has the same number of items in every row AND you know that number in advance, then you can do this:

    df.with_columns(
        pl.concat_list(pl.col('lst').list.get(x) + pl.col('val') for x in range(2))
    )
    

    If not, then you can still do this:

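    df.with_columns(
        pl.concat_list(
            # null_on_oob=True returns null instead of raising past the end of a shorter list
            pl.col('lst').list.get(x, null_on_oob=True) + pl.col('val')
            for x in range(df['lst'].list.len().max())
        )
        # keep only the first len(lst) elements of each row
        .list.gather(
            pl.int_ranges(0, pl.col('lst').list.len())
        )
    )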
    

    This works by replacing the 2 from the first case with df['lst'].list.len().max(), which is the length of the longest list in the whole column. The gather then keeps only the first n elements of each row, where n is the length of that row's lst.

    concat_list isn't especially efficient, so an explode/group_by approach might be faster:

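    (df
     .with_row_index('i')                  # remember which row each list came from
     .explode('lst')                       # one row per list element; val is repeated
     .group_by('i', maintain_order=True)   # regroup elements by their original row
     .agg(
         pl.col('lst') + pl.col('val'),    # add val element-wise, collected back into a list
         pl.col('val').first(),            # val is constant within each group
         )
     .drop('i')
    )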