Tags: python, python-polars

Add a column to a polars LazyFrame based on a group-by aggregation of another column


I have a LazyFrame of time, symbols and mid_price:

Example:

time                symbols             mid_price
datetime[ns]        str                 f64
2024-03-01 00:01:00 "PERP_SOL_USDT@…    126.1575
2024-03-01 00:01:00 "PERP_WAVES_USD…    2.71235
2024-03-01 00:01:00 "SOL_USDT@BINAN…    126.005
2024-03-01 00:01:00 "WAVES_USDT@BIN…    2.7085
2024-03-01 00:02:00 "PERP_SOL_USDT@…    126.3825

I want to perform some aggregations over the time dimension (i.e. group by symbol):

aggs = (
    df
        .group_by('symbols')
        .agg([
            pl.col('mid_price').diff(1).alias("change"),
        ])
)

For each unique `symbols` value, I get back a list of the computed values:

>>> aggs.head().collect()

symbols             change
str                 list[f64]
"SOL_USDT@BINAN…    [null, 0.25, … -0.55]
"PERP_SOL_USDT@…    [null, 0.225, … -0.605]
"WAVES_USDT@BIN…    [null, -0.002, … -0.001]
"PERP_WAVES_USD…    [null, -0.00255, … 0.0001]

I would now like to join this back onto my original dataframe:

df = df.join(
    aggs,
    on='symbols',
    how='left',
)

This results in each row getting the full list of change values, rather than its respective value.

>>> df.head().collect()

time                symbols             mid_price   change
datetime[ns]        str                 f64         list[f64]
2024-03-01 00:01:00 "PERP_SOL_USDT@…    126.1575    [null, 0.225, … -0.605]
2024-03-01 00:01:00 "PERP_WAVES_USD…    2.71235     [null, -0.00255, … 0.0001]
2024-03-01 00:01:00 "SOL_USDT@BINAN…    126.005     [null, 0.25, … -0.55]
2024-03-01 00:01:00 "WAVES_USDT@BIN…    2.7085      [null, -0.002, … -0.001]
2024-03-01 00:02:00 "PERP_SOL_USDT@…    126.3825    [null, 0.225, … -0.605]

I have two questions:

  1. How do I unstack/explode the lists returned from my group_by when joining them back into the original dataframe?
  2. Is this the recommended way to add a new column to my original dataframe from a group_by (that is: group_by followed by join)?

Solution

  • It sounds like you don't want to actually aggregate anything (and get a single value per symbol), but instead want to compute "change" but independently for each symbol.

    In polars, this kind of behaviour, similar to window functions in PostgreSQL, can be achieved with pl.Expr.over.

    df.with_columns(
        pl.col("mid_price").diff(1).over("symbols").alias("change")
    )
    

    On some example data (note the column is named symbol here), the result looks as follows.

    import polars as pl
    import numpy as np
    import datetime
    
    df = pl.DataFrame({
        "symbol": ["A"] * 3 + ["B"] * 3 + ["C"] * 3,
        "time": [datetime.datetime(2024, 3, 1, hour) for hour in range(3)] * 3,
        "mid_price": np.random.randn(9),
    })
    
    df.with_columns(
        pl.col("mid_price").diff(1).over("symbol").alias("change")
    )
    
    shape: (9, 4)
    ┌────────┬─────────────────────┬───────────┬───────────┐
    │ symbol ┆ time                ┆ mid_price ┆ change    │
    │ ---    ┆ ---                 ┆ ---       ┆ ---       │
    │ str    ┆ datetime[μs]        ┆ f64       ┆ f64       │
    ╞════════╪═════════════════════╪═══════════╪═══════════╡
    │ A      ┆ 2024-03-01 00:00:00 ┆ -0.349863 ┆ null      │
    │ A      ┆ 2024-03-01 01:00:00 ┆ 0.093732  ┆ 0.443595  │
    │ A      ┆ 2024-03-01 02:00:00 ┆ -1.262064 ┆ -1.355796 │
    │ B      ┆ 2024-03-01 00:00:00 ┆ 1.953929  ┆ null      │
    │ B      ┆ 2024-03-01 01:00:00 ┆ 0.637582  ┆ -1.316348 │
    │ B      ┆ 2024-03-01 02:00:00 ┆ 1.009401  ┆ 0.37182   │
    │ C      ┆ 2024-03-01 00:00:00 ┆ 0.75864   ┆ null      │
    │ C      ┆ 2024-03-01 01:00:00 ┆ -0.866227 ┆ -1.624867 │
    │ C      ┆ 2024-03-01 02:00:00 ┆ -0.674938 ┆ 0.191289  │
    └────────┴─────────────────────┴───────────┴───────────┘
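  • To directly answer question 1: if you do want to keep the group_by/join route, you can carry the time column through the aggregation so each list element stays paired with its timestamp, then explode both list columns back into rows before joining on both keys. A minimal sketch on made-up data (the example symbols and prices below are illustrative, not from the question); the `over` approach above is simpler and is what I'd recommend:

    ```python
    import datetime

    import polars as pl

    lf = pl.LazyFrame({
        "symbol": ["A"] * 3 + ["B"] * 3,
        "time": [datetime.datetime(2024, 3, 1, h) for h in range(3)] * 2,
        "mid_price": [1.0, 1.25, 0.7, 10.0, 9.5, 9.6],
    })

    # Keep `time` inside the aggregation so each element of the `change`
    # list stays aligned with its timestamp, then explode both lists
    # back into one row per (symbol, time) pair.
    aggs = (
        lf.group_by("symbol")
        .agg(
            pl.col("time"),
            pl.col("mid_price").diff(1).alias("change"),
        )
        .explode("time", "change")
    )

    # Join on both keys so every row receives its own scalar `change`
    # instead of the whole list.
    out = lf.join(aggs, on=["symbol", "time"], how="left").collect()
    print(out)
    ```

    This stays lazy until `.collect()`, so it works the same on a LazyFrame as on a DataFrame, but note it scans the data twice (once for the aggregation, once for the join), whereas `over` does the per-group computation in a single pass.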