Tags: python, group-by, aggregate, python-polars

Polars rolling only retains index and by columns


Using the same dataframe as in a previous question,

df = pl.DataFrame(
    [
        pl.Series("Time", ['02/01/2018 07:05', '02/01/2018 07:07', '02/01/2018 07:08', '02/01/2018 07:09', '02/01/2018 07:10', '02/01/2018 07:12', '02/01/2018 07:13', '02/01/2018 07:14', '02/01/2018 07:18', '02/01/2018 07:26', '02/01/2018 07:38', '02/01/2018 07:39', '02/01/2018 07:45', '02/01/2018 07:48', '02/01/2018 07:49', '02/01/2018 07:50', '02/01/2018 07:52', '02/01/2018 07:53', '02/01/2018 07:56', '02/01/2018 07:57'], dtype=pl.String),
        pl.Series("Open", [8.05, 8.04, 8.02, 8.02, 8.02, 8.01, 8.01, 8.0, 7.99, 7.99, 7.99, 7.99, 7.99, 7.98, 7.96, 7.96, 7.95, 7.94, 7.94, 7.93], dtype=pl.Float64),
        pl.Series("High", [8.05, 8.04, 8.02, 8.02, 8.02, 8.01, 8.01, 8.0, 7.99, 7.99, 7.99, 7.99, 7.99, 7.98, 7.96, 7.96, 7.95, 7.94, 7.94, 7.93], dtype=pl.Float64),
        pl.Series("Low", [8.05, 8.01, 8.01, 8.02, 8.02, 8.01, 8.01, 8.0, 7.99, 7.99, 7.98, 7.99, 7.99, 7.98, 7.94, 7.96, 7.94, 7.94, 7.94, 7.93], dtype=pl.Float64),
        pl.Series("Close", [8.05, 8.01, 8.02, 8.02, 8.02, 8.01, 8.01, 8.0, 7.99, 7.99, 7.98, 7.99, 7.99, 7.98, 7.94, 7.96, 7.94, 7.94, 7.94, 7.93], dtype=pl.Float64),
        pl.Series("MA14", [8.13, 8.12, 8.11, 8.11, 8.1, 8.09, 8.08, 8.07, 8.06, 8.05, 8.04, 8.03, 8.02, 8.0, 8.0, 7.99, 7.99, 7.98, 7.98, 7.97], dtype=pl.Float64),
        pl.Series("MA28", [8.1, 8.09, 8.09, 8.09, 8.09, 8.09, 8.09, 8.09, 8.09, 8.08, 8.08, 8.08, 8.07, 8.07, 8.06, 8.06, 8.05, 8.04, 8.04, 8.03], dtype=pl.Float64),
        pl.Series("PVT", [0.0, -0.3478, -0.2904, -0.2904, -0.2904, -0.3527, -0.3527, -0.3677, -0.374, -0.374, -0.404, -0.3376, -0.3376, -0.3489, -0.4792, -0.459, -0.6224, -0.6224, -0.6224, -0.6362], dtype=pl.Float64),
    ]
)

and using a previous answer to add a few more columns,

ohlc = df.with_columns(
        # explicit day-first format, so "02/01/2018" parses as 2018-01-02
        DateTime = pl.col("Time").str.to_datetime("%d/%m/%Y %H:%M")
    ).with_columns(
        Date = pl.col("DateTime").dt.date().set_sorted(), 
        t = pl.col("DateTime").dt.time(),
        ones = pl.lit(1),
        vol = (2*(pl.col("High")-pl.col("Low"))/(pl.col("Open")+pl.col("Close"))).round(5),   
    ).with_columns(
        MA = pl.col("Close").rolling_mean(20).over("Date"),
        n = pl.col("ones").cum_sum().over("Date")
    ).select(pl.exclude("ones"))

I want to create forward-looking rolling lists using the rolling method:

pl.Config(fmt_str_lengths=100) # increase repr

out = ohlc.rolling(
        index_column = 'n',
        period = '20i',
        offset = '0i',
        group_by = "Date",
    ).agg(
        pl.format("[{}]", pl.col("Close").str.join(",")).alias("lists")
    )

The problem is that print(out.head()) gives:

shape: (5, 3)
┌────────────┬─────┬─────────────────────────────────────────────────────────────────────────────────────────────────┐
│ Date       ┆ n   ┆ lists                                                                                           │
│ ---        ┆ --- ┆ ---                                                                                             │
│ date       ┆ i32 ┆ str                                                                                             │
╞════════════╪═════╪═════════════════════════════════════════════════════════════════════════════════════════════════╡
│ 2018-01-02 ┆ 1   ┆ [8.01,8.02,8.02,8.02,8.01,8.01,8.0,7.99,7.99,7.98,7.99,7.99,7.98,7.94,7.96,7.94,7.94,7.94,7.93] │
│ 2018-01-02 ┆ 2   ┆ [8.02,8.02,8.02,8.01,8.01,8.0,7.99,7.99,7.98,7.99,7.99,7.98,7.94,7.96,7.94,7.94,7.94,7.93]      │
│ 2018-01-02 ┆ 3   ┆ [8.02,8.02,8.01,8.01,8.0,7.99,7.99,7.98,7.99,7.99,7.98,7.94,7.96,7.94,7.94,7.94,7.93]           │
│ 2018-01-02 ┆ 4   ┆ [8.02,8.01,8.01,8.0,7.99,7.99,7.98,7.99,7.99,7.98,7.94,7.96,7.94,7.94,7.94,7.93]                │
│ 2018-01-02 ┆ 5   ┆ [8.01,8.01,8.0,7.99,7.99,7.98,7.99,7.99,7.98,7.94,7.96,7.94,7.94,7.94,7.93]                     │
└────────────┴─────┴─────────────────────────────────────────────────────────────────────────────────────────────────┘

so most of the original columns have been dropped. Is there a way to retain them? Ideally, the output would look something like:

shape: (5, 15)
┌──────────────────┬──────┬──────┬───┬──────┬─────┬─────────────────────────────────┐
│ Time             ┆ Open ┆ High ┆ … ┆ MA   ┆ n   ┆ lists                           │
│ ---              ┆ ---  ┆ ---  ┆   ┆ ---  ┆ --- ┆ ---                             │
│ str              ┆ f64  ┆ f64  ┆   ┆ f64  ┆ i32 ┆ str                             │
╞══════════════════╪══════╪══════╪═══╪══════╪═════╪═════════════════════════════════╡
│ 02/01/2018 07:05 ┆ 8.05 ┆ 8.05 ┆ … ┆ null ┆ 1   ┆ [8.01,8.02,8.02,8.02,8.01,8.01… │
│ 02/01/2018 07:07 ┆ 8.04 ┆ 8.04 ┆ … ┆ null ┆ 2   ┆ [8.02,8.02,8.02,8.01,8.01,8.0,… │
│ 02/01/2018 07:08 ┆ 8.02 ┆ 8.02 ┆ … ┆ null ┆ 3   ┆ [8.02,8.02,8.01,8.01,8.0,7.99,… │
│ 02/01/2018 07:09 ┆ 8.02 ┆ 8.02 ┆ … ┆ null ┆ 4   ┆ [8.02,8.01,8.01,8.0,7.99,7.99,… │
│ 02/01/2018 07:10 ┆ 8.02 ┆ 8.02 ┆ … ┆ null ┆ 5   ┆ [8.01,8.01,8.0,7.99,7.99,7.98,… │
└──────────────────┴──────┴──────┴───┴──────┴─────┴─────────────────────────────────┘

Solution

  • The result of any group_by/agg (and likewise rolling/agg) contains only the key columns (here the group_by column and the index column) plus the expressions you pass to agg, and nothing more.

    It's not clear in what form you want to retain the original columns once they're grouped, but here is the most generic approach, where each one becomes a list of all the original values in its window.

    ohlc.rolling(
            index_column = 'n',
            period = '20i',
            offset = '0i',
            group_by = "Date",
        ).agg(
            pl.format("[{}]", pl.col("Close").str.join(",")).alias("lists"),
            pl.exclude('n','Date')
        )
    shape: (20, 15)
    ┌────────────┬─────┬─────────────────────────────────┬───┬─────────────────────────────────┬───────────────────────────┬────────────────────────┐
    │ Date       ┆ n   ┆ lists                           ┆ … ┆ t                               ┆ vol                       ┆ MA                     │
    │ ---        ┆ --- ┆ ---                             ┆   ┆ ---                             ┆ ---                       ┆ ---                    │
    │ date       ┆ i32 ┆ str                             ┆   ┆ list[time]                      ┆ list[f64]                 ┆ list[f64]              │
    ╞════════════╪═════╪═════════════════════════════════╪═══╪═════════════════════════════════╪═══════════════════════════╪════════════════════════╡
    │ 2018-01-02 ┆ 1   ┆ [8.01,8.02,8.02,8.02,8.01,8.01… ┆ … ┆ [07:07:00, 07:08:00, … 07:57:0… ┆ [0.00374, 0.00125, … 0.0] ┆ [null, null, … 7.9855] │
    │ 2018-01-02 ┆ 2   ┆ [8.02,8.02,8.02,8.01,8.01,8.0,… ┆ … ┆ [07:08:00, 07:09:00, … 07:57:0… ┆ [0.00125, 0.0, … 0.0]     ┆ [null, null, … 7.9855] │
    │ 2018-01-02 ┆ 3   ┆ [8.02,8.02,8.01,8.01,8.0,7.99,… ┆ … ┆ [07:09:00, 07:10:00, … 07:57:0… ┆ [0.0, 0.0, … 0.0]         ┆ [null, null, … 7.9855] │
    │ 2018-01-02 ┆ 4   ┆ [8.02,8.01,8.01,8.0,7.99,7.99,… ┆ … ┆ [07:10:00, 07:12:00, … 07:57:0… ┆ [0.0, 0.0, … 0.0]         ┆ [null, null, … 7.9855] │
    │ 2018-01-02 ┆ 5   ┆ [8.01,8.01,8.0,7.99,7.99,7.98,… ┆ … ┆ [07:12:00, 07:13:00, … 07:57:0… ┆ [0.0, 0.0, … 0.0]         ┆ [null, null, … 7.9855] │
    │ …          ┆ …   ┆ …                               ┆ … ┆ …                               ┆ …                         ┆ …                      │
    │ 2018-01-02 ┆ 16  ┆ [7.94,7.94,7.94,7.93]           ┆ … ┆ [07:52:00, 07:53:00, … 07:57:0… ┆ [0.00126, 0.0, … 0.0]     ┆ [null, null, … 7.9855] │
    │ 2018-01-02 ┆ 17  ┆ [7.94,7.94,7.93]                ┆ … ┆ [07:53:00, 07:56:00, 07:57:00]  ┆ [0.0, 0.0, 0.0]           ┆ [null, null, 7.9855]   │
    │ 2018-01-02 ┆ 18  ┆ [7.94,7.93]                     ┆ … ┆ [07:56:00, 07:57:00]            ┆ [0.0, 0.0]                ┆ [null, 7.9855]         │
    │ 2018-01-02 ┆ 19  ┆ [7.93]                          ┆ … ┆ [07:57:00]                      ┆ [0.0]                     ┆ [7.9855]               │
    │ 2018-01-02 ┆ 20  ┆ []                              ┆ … ┆ []                              ┆ []                        ┆ []                     │
    └────────────┴─────┴─────────────────────────────────┴───┴─────────────────────────────────┴───────────────────────────┴────────────────────────┘
    

    where pl.exclude means all the columns except the ones you list. It's a shortcut for pl.all().exclude(), which reads more intuitively but is more typing. Since 'n' and 'Date' are already in the output as the index column and the group_by key, you don't want to ask for them again with pl.all().

    Based on your edit (one row per original row, with the original scalar values), do a self-join on Date and n to restore the columns.

    ohlc.join(ohlc.rolling(
            index_column = 'n',
            period = '20i',
            offset = '0i',
            group_by = "Date",
        ).agg(
            pl.format("[{}]", pl.col("Close").str.join(",")).alias("lists")
        ), on=['Date','n'])
    
    shape: (20, 15)
    ┌──────────────────┬──────┬──────┬───┬────────┬─────┬─────────────────────────────────┐
    │ Time             ┆ Open ┆ High ┆ … ┆ MA     ┆ n   ┆ lists                           │
    │ ---              ┆ ---  ┆ ---  ┆   ┆ ---    ┆ --- ┆ ---                             │
    │ str              ┆ f64  ┆ f64  ┆   ┆ f64    ┆ i32 ┆ str                             │
    ╞══════════════════╪══════╪══════╪═══╪════════╪═════╪═════════════════════════════════╡
    │ 02/01/2018 07:05 ┆ 8.05 ┆ 8.05 ┆ … ┆ null   ┆ 1   ┆ [8.01,8.02,8.02,8.02,8.01,8.01… │
    │ 02/01/2018 07:07 ┆ 8.04 ┆ 8.04 ┆ … ┆ null   ┆ 2   ┆ [8.02,8.02,8.02,8.01,8.01,8.0,… │
    │ 02/01/2018 07:08 ┆ 8.02 ┆ 8.02 ┆ … ┆ null   ┆ 3   ┆ [8.02,8.02,8.01,8.01,8.0,7.99,… │
    │ 02/01/2018 07:09 ┆ 8.02 ┆ 8.02 ┆ … ┆ null   ┆ 4   ┆ [8.02,8.01,8.01,8.0,7.99,7.99,… │
    │ 02/01/2018 07:10 ┆ 8.02 ┆ 8.02 ┆ … ┆ null   ┆ 5   ┆ [8.01,8.01,8.0,7.99,7.99,7.98,… │
    │ …                ┆ …    ┆ …    ┆ … ┆ …      ┆ …   ┆ …                               │
    │ 02/01/2018 07:50 ┆ 7.96 ┆ 7.96 ┆ … ┆ null   ┆ 16  ┆ [7.94,7.94,7.94,7.93]           │
    │ 02/01/2018 07:52 ┆ 7.95 ┆ 7.95 ┆ … ┆ null   ┆ 17  ┆ [7.94,7.94,7.93]                │
    │ 02/01/2018 07:53 ┆ 7.94 ┆ 7.94 ┆ … ┆ null   ┆ 18  ┆ [7.94,7.93]                     │
    │ 02/01/2018 07:56 ┆ 7.94 ┆ 7.94 ┆ … ┆ null   ┆ 19  ┆ [7.93]                          │
    │ 02/01/2018 07:57 ┆ 7.93 ┆ 7.93 ┆ … ┆ 7.9855 ┆ 20  ┆ []                              │
    └──────────────────┴──────┴──────┴───┴────────┴─────┴─────────────────────────────────┘