python, python-polars

Explode Polars rows on multiple columns but with different logic


I have this code, which splits the product column into a list and then uses explode to expand it:

import polars as pl
import datetime as dt
from dateutil.relativedelta import relativedelta

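def get_3_month_splits(product: str) -> list[str]:
    # 'CHECK.GB.202403.12' -> front='CHECK.GB', start='202403', total=12 months
    front, start_dt, total_m = product.rsplit('.', 2)
    start_dt = dt.datetime.strptime(start_dt, '%Y%m')
    total_m  = int(total_m)
    # One 3-month sub-product per quarter, e.g. 202403, 202406, 202409, 202412
    return [f'{front}.{(start_dt+relativedelta(months=m)).strftime("%Y%m")}.3' for m in range(0, total_m, 3)]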

df = pl.DataFrame({
    'product':    ['CHECK.GB.202403.12', 'CHECK.DE.202506.6', 'CASH.US.202509.12'],
    'qty':        [10, -20, 50],
    'price_paid': [1400, -3300, 900],
})

print(
    df.with_columns(
        pl.col('product').map_elements(get_3_month_splits, return_dtype=pl.List(pl.String))
    ).explode('product')
)

This currently gives:

shape: (10, 3)
┌───────────────────┬─────┬────────────┐
│ product           ┆ qty ┆ price_paid │
│ ---               ┆ --- ┆ ---        │
│ str               ┆ i64 ┆ i64        │
╞═══════════════════╪═════╪════════════╡
│ CHECK.GB.202403.3 ┆ 10  ┆ 1400       │
│ CHECK.GB.202406.3 ┆ 10  ┆ 1400       │
│ CHECK.GB.202409.3 ┆ 10  ┆ 1400       │
│ CHECK.GB.202412.3 ┆ 10  ┆ 1400       │
│ CHECK.DE.202506.3 ┆ -20 ┆ -3300      │
│ CHECK.DE.202509.3 ┆ -20 ┆ -3300      │
│ CASH.US.202509.3  ┆ 50  ┆ 900        │
│ CASH.US.202512.3  ┆ 50  ┆ 900        │
│ CASH.US.202603.3  ┆ 50  ┆ 900        │
│ CASH.US.202606.3  ┆ 50  ┆ 900        │
└───────────────────┴─────┴────────────┘

However, I want the total price paid to stay the same. So after splitting each row into several sub-products, I want the table to look like this:

shape: (10, 3)
┌───────────────────┬─────┬────────────┐
│ product           ┆ qty ┆ price_paid │
│ ---               ┆ --- ┆ ---        │
│ str               ┆ i64 ┆ i64        │
╞═══════════════════╪═════╪════════════╡
│ CHECK.GB.202403.3 ┆ 10  ┆ 1400       │
│ CHECK.GB.202406.3 ┆ 10  ┆ 0          │
│ CHECK.GB.202409.3 ┆ 10  ┆ 0          │
│ CHECK.GB.202412.3 ┆ 10  ┆ 0          │
│ CHECK.DE.202506.3 ┆ -20 ┆ -3300      │
│ CHECK.DE.202509.3 ┆ -20 ┆ 0          │
│ CASH.US.202509.3  ┆ 50  ┆ 900        │
│ CASH.US.202512.3  ┆ 50  ┆ 0          │
│ CASH.US.202603.3  ┆ 50  ┆ 0          │
│ CASH.US.202606.3  ┆ 50  ┆ 0          │
└───────────────────┴─────┴────────────┘

That is, only the first exploded row of each original product keeps its price_paid and the others get 0, so the total price paid stays the same. The qty column can stay as it is.
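For the sample data, the arithmetic shows why the plain explode is a problem: it repeats each price once per sub-product (4, 2 and 4 times), which changes the total, while the target table keeps it:

```latex
\begin{aligned}
\text{original total} &= 1400 + (-3300) + 900 = -1000\\
\text{plain explode}  &= 4 \cdot 1400 + 2 \cdot (-3300) + 4 \cdot 900 = 5600 - 6600 + 3600 = 2600\\
\text{target table}   &= 1400 + (-3300) + 900 = -1000
\end{aligned}
```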

I tried, for example, with_columns(price_arr=pl.col('product').cast(pl.List(pl.Float64))), but then I could not find a way to set only the first element of the list. I also tried with_columns(price_arr=pl.col(['product', 'price_paid']).map_elements(price_func)), but map_elements does not seem to work on pl.col([...]).


Solution

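  • Concatenate the appropriate number of trailing 0s onto price_paid (one fewer than the number of sub-products), then call .explode() on product and price_paid at once. Exploding both columns together works because each row's two lists have the same length: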

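    print(
        df.with_columns(
            # Split each product into its list of 3-month sub-products.
            pl.col("product").map_elements(get_3_month_splits, return_dtype=pl.List(pl.String))
        )
        .with_columns(
            # Build [price_paid, 0, 0, ...] with the same length as the product list.
            pl.concat_list(
                pl.col("price_paid"), pl.lit(0).repeat_by(pl.col("product").list.len() - 1)
            )
        )
        # Explode both list columns together; their per-row lengths match.
        .explode("product", "price_paid")
    )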
    

    Output:

    shape: (10, 3)
    ┌───────────────────┬─────┬────────────┐
    │ product           ┆ qty ┆ price_paid │
    │ ---               ┆ --- ┆ ---        │
    │ str               ┆ i64 ┆ i64        │
    ╞═══════════════════╪═════╪════════════╡
    │ CHECK.GB.202403.3 ┆ 10  ┆ 1400       │
    │ CHECK.GB.202406.3 ┆ 10  ┆ 0          │
    │ CHECK.GB.202409.3 ┆ 10  ┆ 0          │
    │ CHECK.GB.202412.3 ┆ 10  ┆ 0          │
    │ CHECK.DE.202506.3 ┆ -20 ┆ -3300      │
    │ CHECK.DE.202509.3 ┆ -20 ┆ 0          │
    │ CASH.US.202509.3  ┆ 50  ┆ 900        │
    │ CASH.US.202512.3  ┆ 50  ┆ 0          │
    │ CASH.US.202603.3  ┆ 50  ┆ 0          │
    │ CASH.US.202606.3  ┆ 50  ┆ 0          │
    └───────────────────┴─────┴────────────┘