python-polars

Best way to group_by inside a rolling in polars?


Given a dataset that looks like this, where each node has one row per hour. BidType is simply a calculated column: values greater than 5 are Buy, values less than 2 are Sell, and everything else is None.

         Node  date                 Hour  Price  BidType
1349561  N001  2020-12-13 00:00:00    17  30.63      Buy
391333   CE    2020-12-17 00:00:00    13  -2.42     Sell
784166   N002  2020-12-05 00:00:00    14  -0.92     Sell
1191909  MMM   2020-12-14 00:00:00    21  -1.69     Sell
44068    MMS   2020-12-08 00:00:00     4   2.07     None

In polars, I want a rolling average of the price, grouped by Node, Hour, and BidType. For example, for date 2020-12-13, Node N001, Hour 17, and BidType='Buy', I want the mean price over the last 100 days of rows with Node=N001, Hour=17, and BidType='Buy'. And so on for every combination of Node, Hour, and BidType.

My first instinct was to do:

.rolling('date', period='100d', by=['Node', 'Hour', 'BidType']).agg(pl.col('Price').mean())

but it doesn't work: it returns only 24 rows per date per node, whereas I would expect 3 * 24 rows per node, since there are 3 possible BidType values and 24 hours for each node.

Another try would be:

.rolling('date', period='10d', by=['Node', 'Hour'])
    .agg(
        pl.col('Price').filter(pl.col('BidType') == 'Buy').mean().alias('Price_buy'),
        pl.col('Price').filter(pl.col('BidType') == 'Sell').mean().alias('Price_sell'),
        pl.col('Price').filter(pl.col('BidType') == 'None').mean().alias('Price_none')
    )
    .melt(['Node', 'Hour', 'date'])

But this doesn't seem to return what I want either.

What is the recommended way of doing what I want?

Here is some code to generate a dataset to test the problem:

import numpy as np
import pandas as pd
import polars as pl

df = pd.DataFrame(np.random.randint(0, 10, size=(24 * 10 * 100, 1)), columns=['Price'])
df['Node'] = np.concatenate([np.repeat(np.arange(0, 10), 24) for x in range(0, 100)])
df['date'] = np.concatenate([pd.date_range('1/1/2020', periods=24, freq='D') for x in range(0, 100 * 10)])
df['Hour'] = np.concatenate([np.arange(0, 24) for x in range(0, 100 * 10)])
df['Time'] = df.date + pd.to_timedelta(df.Hour, unit='h')

df = pl.DataFrame(df)
df = df.sort('date', 'Hour')
df = df.with_columns(
    pl.when(pl.col('Price') >= 5).then(pl.lit('Buy'))
    .when(pl.col('Price') < 2).then(pl.lit('Sell'))
    .otherwise(pl.lit('None'))
    .alias('BidType')
)
df

Solution

  • To keep the example clear, I am working with a much smaller dataframe. You can find the code that generates it at the end of this answer.

    shape: (8, 4)
    ┌──────┬─────────────────────┬───────┬──────────┐
    │ node ┆ date                ┆ price ┆ bid_type │
    │ ---  ┆ ---                 ┆ ---   ┆ ---      │
    │ str  ┆ datetime[μs]        ┆ i64   ┆ str      │
    ╞══════╪═════════════════════╪═══════╪══════════╡
    │ a    ┆ 2021-01-01 00:00:00 ┆ 5     ┆ null     │
    │ a    ┆ 2021-01-01 01:00:00 ┆ 0     ┆ Sell     │
    │ a    ┆ 2021-01-02 00:00:00 ┆ 3     ┆ null     │
    │ a    ┆ 2021-01-02 01:00:00 ┆ 3     ┆ null     │
    │ b    ┆ 2021-01-01 00:00:00 ┆ 7     ┆ Buy      │
    │ b    ┆ 2021-01-01 01:00:00 ┆ 9     ┆ Buy      │
    │ b    ┆ 2021-01-02 00:00:00 ┆ 3     ┆ null     │
    │ b    ┆ 2021-01-02 01:00:00 ┆ 5     ┆ null     │
    └──────┴─────────────────────┴───────┴──────────┘
    

    As you can see, the dataframe contains data for only 2 nodes, 2 days, and 2 hours per day.

    Problem diagnostics. pl.DataFrame.rolling will, for each row in the dataframe, create a look-back window of length period containing all rows that match the current row's by values.

    In particular, in the example above there won't be a window for node == "a", date == 2021-01-01 00:00:00, and bid_type == "Buy", as there is no corresponding row in the original dataframe. Naively applying pl.DataFrame.rolling().agg() therefore produces the windows shown below (see the sketch after this paragraph).
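
    For reference, a naive call that produces these windows might look like the following sketch (period="2d" is chosen to match the 2-day example data):

    (
        df
        .rolling(
            index_column="date",
            period="2d",
            by=["node", pl.col("date").dt.hour().alias("hour"), pl.col("bid_type")],
        )
        .agg(pl.col("price"))
        .sort("node", "date")
    )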

    shape: (8, 5)
    ┌──────┬──────┬──────────┬─────────────────────┬───────────┐
    │ node ┆ hour ┆ bid_type ┆ date                ┆ price     │
    │ ---  ┆ ---  ┆ ---      ┆ ---                 ┆ ---       │
    │ str  ┆ i8   ┆ str      ┆ datetime[μs]        ┆ list[i64] │
    ╞══════╪══════╪══════════╪═════════════════════╪═══════════╡
    │ a    ┆ 0    ┆ null     ┆ 2021-01-01 00:00:00 ┆ [5]       │
    │ a    ┆ 1    ┆ Sell     ┆ 2021-01-01 01:00:00 ┆ [0]       │
    │ a    ┆ 0    ┆ null     ┆ 2021-01-02 00:00:00 ┆ [5, 3]    │
    │ a    ┆ 1    ┆ null     ┆ 2021-01-02 01:00:00 ┆ [3]       │
    │ b    ┆ 0    ┆ Buy      ┆ 2021-01-01 00:00:00 ┆ [7]       │
    │ b    ┆ 1    ┆ Buy      ┆ 2021-01-01 01:00:00 ┆ [9]       │
    │ b    ┆ 0    ┆ null     ┆ 2021-01-02 00:00:00 ┆ [3]       │
    │ b    ┆ 1    ┆ null     ┆ 2021-01-02 01:00:00 ┆ [5]       │
    └──────┴──────┴──────────┴─────────────────────┴───────────┘
    

    Answer. To achieve the desired effect, you can create a new dataframe that has one row for each node, date, bid_type combination. price in the new dataframe is missing if and only if the original dataframe did not contain that combination.

    # create dense dataframe
    skeleton = (
        df
        .select(pl.col("node").unique())
        .join(df.select("date").unique(), how="cross")
        .join(df.select("bid_type").unique(), how="cross")
        .sort("node", "date", "bid_type")
    )
    
    # join with original prices; join_nulls=True makes null bid_type keys match
    # each other (by default, polars joins never match null keys)
    skeleton.join(df, on=["node", "date", "bid_type"], how="left", join_nulls=True)
    
    shape: (24, 4)
    ┌──────┬─────────────────────┬──────────┬───────┐
    │ node ┆ date                ┆ bid_type ┆ price │
    │ ---  ┆ ---                 ┆ ---      ┆ ---   │
    │ str  ┆ datetime[μs]        ┆ str      ┆ i64   │
    ╞══════╪═════════════════════╪══════════╪═══════╡
    │ a    ┆ 2021-01-01 00:00:00 ┆ null     ┆ 5     │
    │ a    ┆ 2021-01-01 00:00:00 ┆ Buy      ┆ null  │
    │ a    ┆ 2021-01-01 00:00:00 ┆ Sell     ┆ null  │
    │ a    ┆ 2021-01-01 01:00:00 ┆ null     ┆ null  │
    │ a    ┆ 2021-01-01 01:00:00 ┆ Buy      ┆ null  │
    │ a    ┆ 2021-01-01 01:00:00 ┆ Sell     ┆ 0     │
    │ a    ┆ 2021-01-02 00:00:00 ┆ null     ┆ 3     │
    │ a    ┆ 2021-01-02 00:00:00 ┆ Buy      ┆ null  │
    │ a    ┆ 2021-01-02 00:00:00 ┆ Sell     ┆ null  │
    │ a    ┆ 2021-01-02 01:00:00 ┆ null     ┆ 3     │
    │ a    ┆ 2021-01-02 01:00:00 ┆ Buy      ┆ null  │
    │ a    ┆ 2021-01-02 01:00:00 ┆ Sell     ┆ null  │
    │ b    ┆ 2021-01-01 00:00:00 ┆ null     ┆ null  │
    │ b    ┆ 2021-01-01 00:00:00 ┆ Buy      ┆ 7     │
    │ b    ┆ 2021-01-01 00:00:00 ┆ Sell     ┆ null  │
    │ b    ┆ 2021-01-01 01:00:00 ┆ null     ┆ null  │
    │ b    ┆ 2021-01-01 01:00:00 ┆ Buy      ┆ 9     │
    │ b    ┆ 2021-01-01 01:00:00 ┆ Sell     ┆ null  │
    │ b    ┆ 2021-01-02 00:00:00 ┆ null     ┆ 3     │
    │ b    ┆ 2021-01-02 00:00:00 ┆ Buy      ┆ null  │
    │ b    ┆ 2021-01-02 00:00:00 ┆ Sell     ┆ null  │
    │ b    ┆ 2021-01-02 01:00:00 ┆ null     ┆ 5     │
    │ b    ┆ 2021-01-02 01:00:00 ┆ Buy      ┆ null  │
    │ b    ┆ 2021-01-02 01:00:00 ┆ Sell     ┆ null  │
    └──────┴─────────────────────┴──────────┴───────┘
    

    Applying pl.DataFrame.rolling produces the desired rolling windows.

    (
        skeleton
        .join(df, on=["node", "date", "bid_type"], how="left", join_nulls=True)
        .rolling(
            index_column="date",
            period="2d",
            by=["node", pl.col("date").dt.hour().alias("hour"), pl.col("bid_type")],
        )
        .agg(
            pl.col("price")
        )
        .sort("node", "date", "bid_type")
    )
    

    Note. Use pl.col("price").mean() in the aggregation to obtain the rolling average instead of the rolling windows.

    shape: (24, 5)
    ┌──────┬──────┬──────────┬─────────────────────┬──────────────┐
    │ node ┆ hour ┆ bid_type ┆ date                ┆ price        │
    │ ---  ┆ ---  ┆ ---      ┆ ---                 ┆ ---          │
    │ str  ┆ i8   ┆ str      ┆ datetime[μs]        ┆ list[i64]    │
    ╞══════╪══════╪══════════╪═════════════════════╪══════════════╡
    │ a    ┆ 0    ┆ null     ┆ 2021-01-01 00:00:00 ┆ [5]          │
    │ a    ┆ 0    ┆ Buy      ┆ 2021-01-01 00:00:00 ┆ [null]       │
    │ a    ┆ 0    ┆ Sell     ┆ 2021-01-01 00:00:00 ┆ [null]       │
    │ a    ┆ 1    ┆ null     ┆ 2021-01-01 01:00:00 ┆ [null]       │
    │ a    ┆ 1    ┆ Buy      ┆ 2021-01-01 01:00:00 ┆ [null]       │
    │ a    ┆ 1    ┆ Sell     ┆ 2021-01-01 01:00:00 ┆ [0]          │
    │ a    ┆ 0    ┆ null     ┆ 2021-01-02 00:00:00 ┆ [5, 3]       │
    │ a    ┆ 0    ┆ Buy      ┆ 2021-01-02 00:00:00 ┆ [null, null] │
    │ a    ┆ 0    ┆ Sell     ┆ 2021-01-02 00:00:00 ┆ [null, null] │
    │ a    ┆ 1    ┆ null     ┆ 2021-01-02 01:00:00 ┆ [null, 3]    │
    │ a    ┆ 1    ┆ Buy      ┆ 2021-01-02 01:00:00 ┆ [null, null] │
    │ a    ┆ 1    ┆ Sell     ┆ 2021-01-02 01:00:00 ┆ [0, null]    │
    │ b    ┆ 0    ┆ null     ┆ 2021-01-01 00:00:00 ┆ [null]       │
    │ b    ┆ 0    ┆ Buy      ┆ 2021-01-01 00:00:00 ┆ [7]          │
    │ b    ┆ 0    ┆ Sell     ┆ 2021-01-01 00:00:00 ┆ [null]       │
    │ b    ┆ 1    ┆ null     ┆ 2021-01-01 01:00:00 ┆ [null]       │
    │ b    ┆ 1    ┆ Buy      ┆ 2021-01-01 01:00:00 ┆ [9]          │
    │ b    ┆ 1    ┆ Sell     ┆ 2021-01-01 01:00:00 ┆ [null]       │
    │ b    ┆ 0    ┆ null     ┆ 2021-01-02 00:00:00 ┆ [null, 3]    │
    │ b    ┆ 0    ┆ Buy      ┆ 2021-01-02 00:00:00 ┆ [7, null]    │
    │ b    ┆ 0    ┆ Sell     ┆ 2021-01-02 00:00:00 ┆ [null, null] │
    │ b    ┆ 1    ┆ null     ┆ 2021-01-02 01:00:00 ┆ [null, 5]    │
    │ b    ┆ 1    ┆ Buy      ┆ 2021-01-02 01:00:00 ┆ [9, null]    │
    │ b    ┆ 1    ┆ Sell     ┆ 2021-01-02 01:00:00 ┆ [null, null] │
    └──────┴──────┴──────────┴─────────────────────┴──────────────┘
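
    Concretely, the mean variant mentioned in the note above would be the following sketch (the price_mean alias is my naming choice):

    (
        skeleton
        .join(df, on=["node", "date", "bid_type"], how="left", join_nulls=True)
        .rolling(
            index_column="date",
            period="2d",
            by=["node", pl.col("date").dt.hour().alias("hour"), pl.col("bid_type")],
        )
        .agg(pl.col("price").mean().alias("price_mean"))
        .sort("node", "date", "bid_type")
    )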
    

    Sanity check. As a sanity check, we can set N_NODES = 1, N_DAYS = 365, and N_HOURS = 1 in the data-generation code and set period="365d" in the call to pl.DataFrame.rolling. This ensures that the rolling windows at later timestamps contain enough data points for reasonable Monte Carlo estimates. We get price means of around 0.5, 3.5, and 7.5 for windows of bid type "Sell", None, and "Buy", respectively. This nicely matches the expected values of discrete uniform distributions supported on {0, 1}, {2, 3, 4, 5}, and {6, 7, 8, 9}.
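
    A sketch of that check, assuming both df and skeleton have been rebuilt with the adjusted constants (the filter to the last date is an illustrative choice to inspect only the fullest windows, not part of the original answer):

    (
        skeleton
        .join(df, on=["node", "date", "bid_type"], how="left", join_nulls=True)
        .rolling(
            index_column="date",
            period="365d",
            by=["node", pl.col("date").dt.hour().alias("hour"), pl.col("bid_type")],
        )
        .agg(pl.col("price").mean().alias("price_mean"))
        .filter(pl.col("date") == pl.col("date").max())  # keep only the last, fullest window per group
        .sort("bid_type")
    )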

    Code for data generation.

    import string
    import polars as pl
    import numpy as np
    
    np.random.seed(0)
    
    N_NODES = 2
    N_DAYS = 2
    N_HOURS = 2
    
    # create df of unique nodes
    df_nodes = pl.DataFrame({"node": list(string.ascii_letters[:N_NODES])})
    
    # create df of unique dates
    df_dates = (
        pl.DataFrame({
            "date": pl.datetime_range(
                pl.datetime(2021, 1, 1),
                pl.datetime(2021, 1, 1).dt.offset_by(f"{N_DAYS}d").dt.offset_by("-1h"),
                interval="1h",
                eager=True,
            )
        })
        .filter(pl.col("date").dt.hour() < N_HOURS)
    )
    
    # create df of all node / date combinations
    df = df_nodes.join(df_dates, how="cross")
    
    # add random price and bid type
    df = (
        df
        .with_columns(
            pl.lit(np.random.randint(0, 10, df.height)).alias("price"),
        )
        .with_columns(
            pl.coalesce(
                pl.when(pl.col("price") > 5).then(pl.lit("Buy")),
                pl.when(pl.col("price") < 2).then(pl.lit("Sell")),
            ).alias("bid_type")
        )
    )
    df
    