How to explode a list column and proportionally split an integer column in python polars?


I have a dataset that includes item types and quantities of each item, except that some rows of the item type column contain lists of types instead of a single type. I want to explode the list of types into separate rows and split the quantities proportionally.

In pandas, I would typically explode and then groupby the index, similar to the process described here (the sample data shown on that page is analogous to the data I'm working with).

I can always create my own index in the Polars dataframe, or convert to and from pandas in chunks that fit in memory, but is there a better way to do this in Polars?

Edits:

There seemed to be some confusion in the comments, so for clarity, I have data shaped like so:

┌───────────────────┬─────┐
│ cat               ┆ qty │
│ ---               ┆ --- │
│ list[str]         ┆ i64 │
╞═══════════════════╪═════╡
│ ["green", "blue"] ┆ 23  │
│ ["green"]         ┆ 23  │
│ ["red"]           ┆ 4   │
│ ["blue"]          ┆ 5   │
│ ["red", "blue"]   ┆ 15  │
└───────────────────┴─────┘

And I want:

┌───────┬──────┐
│ cat   ┆ qty  │
│ ---   ┆ ---  │
│ str   ┆ f64  │
╞═══════╪══════╡
│ green ┆ 11.5 │
│ blue  ┆ 11.5 │
│ green ┆ 23.0 │
│ red   ┆ 4.0  │
│ blue  ┆ 5.0  │
│ red   ┆ 7.5  │
│ blue  ┆ 7.5  │
└───────┴──────┘

Which I can get with:

import polars as pl

# convert to pandas
df_out = df.to_pandas()
# explode the list column; each exploded row keeps its original index label
df_out = df_out.explode('cat')
# divide each qty by the number of rows sharing its original index label
df_out['qty'] /= df_out['qty'].groupby(level=0).transform('count')
# back to polars
df_out = pl.from_pandas(df_out)
print(df_out)

One could also do it with a script similar to this (thanks to @Henricks for finding that article!).

I imagine there's a much better way to do this, ideally one that works in lazy mode and doesn't require creating a new column to serve as a pseudo-index, but I haven't come up with anything that works. Thanks!


Solution

  • Similar to the pandas approach, you can divide qty by the number of elements in the corresponding cat list.

    (
        df
        .with_columns(pl.col("qty") / pl.col("cat").list.len())
        .explode("cat")
    )
    
    shape: (7, 2)
    ┌───────┬──────┐
    │ cat   ┆ qty  │
    │ ---   ┆ ---  │
    │ str   ┆ f64  │
    ╞═══════╪══════╡
    │ green ┆ 11.5 │
    │ blue  ┆ 11.5 │
    │ green ┆ 23.0 │
    │ red   ┆ 4.0  │
    │ blue  ┆ 5.0  │
    │ red   ┆ 7.5  │
    │ blue  ┆ 7.5  │
    └───────┴──────┘
    

    Note: unlike the pandas approach, the division happens before exploding the list column, while the length of each list is still available.