python, dataframe, vectorization, python-polars

How can I map a field of a Polars struct from the values of one field `a` to the values of another field `b`?


I have a Polars dataframe with the columns shown below. I want to replace each value in each list in column c with the corresponding value in column b, based on the position that value has in column a.

┌────────────────────┬─────────────────────────────────┬────────────────────┐
│ a                  ┆ b                               ┆ c                  │
│ ---                ┆ ---                             ┆ ---                │
│ list[f32]          ┆ list[f64]                       ┆ list[f32]          │
╞════════════════════╪═════════════════════════════════╪════════════════════╡
│ [1.0, 0.0]         ┆ [0.001143, 0.998857]            ┆ [0.0, 0.5, … 1.5]  │
│ [5.0, 6.0, … 4.0]  ┆ [0.000286, 0.000143, … 0.00357… ┆ [0.0, 0.5, … 6.5]  │
│ [1.0, 0.0]         ┆ [0.005287, 0.994713]            ┆ [0.0, 0.5, … 1.5]  │
│ [0.0, 1.5, … 2.5]  ┆ [0.84367, 0.003858, … 0.000429… ┆ [0.0, 0.5, … 3.5]  │
│ [5.0, 6.0, … 1.0]  ┆ [0.001286, 0.000286, … 0.35267… ┆ [0.0, 0.5, … 6.5]  │
│ …                  ┆ …                               ┆ …                  │
│ [0.0, 1.0]         ┆ [0.990283, 0.009717]            ┆ [0.0, 0.5, … 1.5]  │
│ [5.0, 1.0, … 0.0]  ┆ [0.003001, 0.352672, … 0.42855… ┆ [0.0, 0.5, … 6.5]  │
│ [0.0, 2.0, … 3.0]  ┆ [0.90383, 0.004716, … 0.000143… ┆ [0.0, 0.5, … 3.5]  │
│ [2.0, 0.0, … 9.0]  ┆ [0.233352, 0.060446, … 0.00228… ┆ [0.0, 0.5, … 10.5] │
│ [5.0, 8.0, … 11.0] ┆ [0.134467, 0.022578, … 0.00085… ┆ [0.0, 0.5, … 12.5] │
└────────────────────┴─────────────────────────────────┴────────────────────┘

Here is my attempt:

df = df.with_columns(
    pl.struct("a", "b", "c").alias("d").struct.with_fields(
        pl.field("c").replace(old=pl.field("a"), new=pl.field("b"), default=0)
    )
)

Unfortunately, this raises polars.exceptions.InvalidOperationError: `old` input for `replace` must not contain duplicates. However, field "a", which is passed as the old argument, holds the unique values from Expr.value_counts(), so it shouldn't contain any duplicates. Indeed, df.select(pl.col("lines").list.eval(pl.element().is_duplicated().any()).explode().any()) returns false.

Small chunk of the data to reproduce:

df = pl.DataFrame([
    pl.Series('a', [[1.0, 0.0], [5.0, 6.0, 1.0, 0.0, 3.0, 2.0, 4.0], [1.0, 0.0], [0.0, 1.5, 3.0, 1.0, 2.0, 0.5, 2.5], [5.0, 6.0, 0.0, 3.0, 2.0, 4.0, 1.0]]),
    pl.Series('b', [[0.0011431837667905116, 0.9988568162332095], [0.0002857959416976279, 0.00014289797084881395, 0.2842240640182909, 0.5985995998856817, 0.019291226064589884, 0.09388396684767077, 0.003572449271220349], [0.005287224921406116, 0.9947127750785939], [0.8436696198913975, 0.0038582452129179764, 0.00014289797084881395, 0.10703058016576164, 0.007859388396684767, 0.03701057444984281, 0.00042869391254644185], [0.0012860817376393256, 0.0002857959416976279, 0.4645613032294941, 0.038153758216633325, 0.13561017433552444, 0.007430694484138326, 0.3526721920548728]]),
    pl.Series('c', [[0.0, 0.5, 1.0, 1.5], [0.0, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 5.0, 5.5, 6.0, 6.5], [0.0, 0.5, 1.0, 1.5], [0.0, 0.5, 0.5, 1.0, 1.0, 1.5, 1.5, 2.0, 2.0, 2.5, 2.5, 3.0, 3.0, 3.5], [0.0, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 5.0, 5.5, 6.0, 6.5]]),
])

Example of the desired output: for the first row, the new column/struct field should look like [0.998857, 0, ..., 0], because the values of "c" are [0.0, 0.5, ... 1.5], the value 0.0 is at index 1 in column "a", and the value at index 1 in column "b" is 0.998857. The values 0.5 and 1.5 would get the default 0 because they do not appear in "a".
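The per-row logic described above can be sketched in plain Python as a slow, non-vectorized reference for checking results. This is only an illustration: `map_row` is a hypothetical helper name, not part of Polars, and the b values are the rounded ones from the table.

```python
def map_row(a, b, c, default=0):
    """Replace each value in c with the b value at the position it has in a."""
    lookup = dict(zip(a, b))  # value in a -> value at the same position in b
    return [lookup.get(value, default) for value in c]


# First row of the sample data (b rounded as in the table above)
row = map_row(a=[1.0, 0.0], b=[0.001143, 0.998857], c=[0.0, 0.5, 1.0, 1.5])
print(row)
```

Values of c that do not occur in a fall back to `default`, which matches the `default=0` used in the attempt above.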

Is this just not possible? I am really hoping to find a vectorized way to do this.

Any help is appreciated, thanks.


Solution

  • Let's take an inventory of everything we need for this to work. The items are listed in the order you might think of them, which is not necessarily the order in which the code performs them.

    1. Compare list columns element by element. This can't be done directly on lists, so the lists need to be exploded.
    2. Since the columns to compare don't have the same size/shape, each should be split into its own df before exploding, and the two then joined.
    3. Create a row index on the original df so that, when the separated dfs are joined, the original row index is one of the join keys.
    4. Create an element index for the a column, since we'll need those positions explicitly.
    5. Group by the original df's row index to build a new list of index positions.
    6. Join that new df of element index positions back to the original df and gather from b by the new index column.
    7. Replace the nulls in the final column with 0.

    Here's the code, with comments pointing back to the item numbers above:

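    import polars as pl

    df = df.with_row_index("i")  # item 3: row index used as a join key

    # item 2: split a and c into their own frames
    a = df.select(
        "i", "a",
        ai=pl.int_ranges(0, pl.col("a").list.len()),  # item 4: position of each element in a
    ).explode("a", "ai")  # item 1
    c = df.select("i", "c").explode("c")  # item 1

    ai = (
        # item 2: the right join keeps every element of c; ai is null where c has no match in a
        a.join(c, left_on=["a", "i"], right_on=["c", "i"], how="right")
        # item 5: collect the positions back into one list per original row
        .group_by("i", maintain_order=True)
        .agg("ai")
    )

    df = (
        df
        .join(ai, on="i")  # item 6
        .with_columns(
            z=pl.col("b").list.gather(pl.col("ai"))  # null positions come back as null
            .list.eval(pl.element().fill_null(0))  # item 7
        )
        .drop("i")
    )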