I have this code, which splits a product column into a list and then uses explode to expand it:
import polars as pl
import datetime as dt
from dateutil.relativedelta import relativedelta


def get_3_month_splits(product: str) -> list[str]:
    # "CHECK.GB.202403.12" -> front="CHECK.GB", start_dt="202403", total_m="12"
    front, start_dt, total_m = product.rsplit('.', 2)
    start_dt = dt.datetime.strptime(start_dt, '%Y%m')
    total_m = int(total_m)
    # One 3-month product per quarter covered by the original product
    return [
        f'{front}.{(start_dt + relativedelta(months=m)).strftime("%Y%m")}.3'
        for m in range(0, total_m, 3)
    ]


df = pl.DataFrame({
    'product': ['CHECK.GB.202403.12', 'CHECK.DE.202506.6', 'CASH.US.202509.12'],
    'qty': [10, -20, 50],
    'price_paid': [1400, -3300, 900],
})

print(
    df.with_columns(
        pl.col('product').map_elements(get_3_month_splits, return_dtype=pl.List(str))
    ).explode('product')
)
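If you want to avoid the dateutil dependency, you can do the month arithmetic in get_3_month_splits with divmod instead. This is a sketch that assumes the middle field is always a six-digit YYYYMM value. The name get_3_month_splits_stdlib is hypothetical, and the function is meant as a drop-in replacement inside the same map_elements call.

```python
def get_3_month_splits_stdlib(product: str) -> list[str]:
    # Split "FRONT.YYYYMM.N" into the prefix, the start month and the total months
    front, start, total_m = product.rsplit(".", 2)
    year, month = int(start[:4]), int(start[4:])
    splits = []
    for m in range(0, int(total_m), 3):
        # Work in zero-based months so divmod handles the year roll-over
        y, mo = divmod((month - 1) + m, 12)
        splits.append(f"{front}.{year + y}{mo + 1:02d}.3")
    return splits


print(get_3_month_splits_stdlib("CASH.US.202509.12"))
```

For "CASH.US.202509.12" this yields the same four products shown in the table below: 202509, 202512, 202603 and 202606.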
This currently gives:
shape: (10, 3)
┌───────────────────┬─────┬────────────┐
│ product ┆ qty ┆ price_paid │
│ --- ┆ --- ┆ --- │
│ str ┆ i64 ┆ i64 │
╞═══════════════════╪═════╪════════════╡
│ CHECK.GB.202403.3 ┆ 10 ┆ 1400 │
│ CHECK.GB.202406.3 ┆ 10 ┆ 1400 │
│ CHECK.GB.202409.3 ┆ 10 ┆ 1400 │
│ CHECK.GB.202412.3 ┆ 10 ┆ 1400 │
│ CHECK.DE.202506.3 ┆ -20 ┆ -3300 │
│ CHECK.DE.202509.3 ┆ -20 ┆ -3300 │
│ CASH.US.202509.3 ┆ 50 ┆ 900 │
│ CASH.US.202512.3 ┆ 50 ┆ 900 │
│ CASH.US.202603.3 ┆ 50 ┆ 900 │
│ CASH.US.202606.3 ┆ 50 ┆ 900 │
└───────────────────┴─────┴────────────┘
However, I want the total price paid to stay the same. So after splitting each row into several "sub-categories", I want the table to look like this:
shape: (10, 3)
┌───────────────────┬─────┬────────────┐
│ product ┆ qty ┆ price_paid │
│ --- ┆ --- ┆ --- │
│ str ┆ i64 ┆ i64 │
╞═══════════════════╪═════╪════════════╡
│ CHECK.GB.202403.3 ┆ 10 ┆ 1400 │
│ CHECK.GB.202406.3 ┆ 10 ┆ 0 │
│ CHECK.GB.202409.3 ┆ 10 ┆ 0 │
│ CHECK.GB.202412.3 ┆ 10 ┆ 0 │
│ CHECK.DE.202506.3 ┆ -20 ┆ -3300 │
│ CHECK.DE.202509.3 ┆ -20 ┆ 0 │
│ CASH.US.202509.3 ┆ 50 ┆ 900 │
│ CASH.US.202512.3 ┆ 50 ┆ 0 │
│ CASH.US.202603.3 ┆ 50 ┆ 0 │
│ CASH.US.202606.3 ┆ 50 ┆ 0 │
└───────────────────┴─────┴────────────┘
That is, price_paid should be kept only in the first expanded row of each original row, so that the total price paid stays the same. qty can stay as it is.
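In plain Python terms, the desired price_paid column is each original price followed by n - 1 zeros, where n is the number of splits for that row. The sum is therefore unchanged. spread_first is a hypothetical helper used only to illustrate the rule, and the split counts (4, 2, 4) are read off the example table.

```python
def spread_first(prices: list[int], counts: list[int]) -> list[int]:
    # For each original row, keep its price on the first split and pad with zeros
    out = []
    for price, n in zip(prices, counts):
        out.extend([price] + [0] * (n - 1))
    return out


prices = [1400, -3300, 900]
counts = [4, 2, 4]  # number of 3-month splits per original row
expanded = spread_first(prices, counts)
print(expanded, sum(expanded) == sum(prices))
```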
I tried, e.g., with_columns(price_arr=pl.col('product').cast(pl.List(pl.Float64))), but was then unable to add anything to the first element of the list. I also tried with_columns(price_arr=pl.col(['product', 'price_paid']).map_elements(price_func)), but it did not seem possible to use map_elements on pl.col([...]).
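Concatenate the appropriate number of trailing 0s to price_paid, so that each row's price_paid list has the same length as its product list, and then call .explode() on both product and price_paid at once (exploding several columns together requires matching list lengths in each row):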
print(
    df.with_columns(
        pl.col("product").map_elements(get_3_month_splits, return_dtype=pl.List(str))
    )
    .with_columns(
        # [price_paid, 0, 0, ...]: one entry per element of the product list
        pl.concat_list(
            pl.col("price_paid"), pl.lit(0).repeat_by(pl.col("product").list.len() - 1)
        )
    )
    .explode("product", "price_paid")
)
Output:
shape: (10, 3)
┌───────────────────┬─────┬────────────┐
│ product ┆ qty ┆ price_paid │
│ --- ┆ --- ┆ --- │
│ str ┆ i64 ┆ i64 │
╞═══════════════════╪═════╪════════════╡
│ CHECK.GB.202403.3 ┆ 10 ┆ 1400 │
│ CHECK.GB.202406.3 ┆ 10 ┆ 0 │
│ CHECK.GB.202409.3 ┆ 10 ┆ 0 │
│ CHECK.GB.202412.3 ┆ 10 ┆ 0 │
│ CHECK.DE.202506.3 ┆ -20 ┆ -3300 │
│ CHECK.DE.202509.3 ┆ -20 ┆ 0 │
│ CASH.US.202509.3 ┆ 50 ┆ 900 │
│ CASH.US.202512.3 ┆ 50 ┆ 0 │
│ CASH.US.202603.3 ┆ 50 ┆ 0 │
│ CASH.US.202606.3 ┆ 50 ┆ 0 │
└───────────────────┴─────┴────────────┘