
Optimize assigning an index to groups of split data in Polars


SOLVED:

Fastest function: 995x faster than the original function

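import polars as pl


def add_range_index_stack2(data, range_str):
    # Convert e.g. "1 month" into a Polars duration string such as "1mo"
    # (`_range_format` is defined further down in this post).
    range_str = _range_format(range_str)
    # One row per (symbol, window start); the empty .agg() keeps only the keys.
    # Assumes rows are sorted by date within each symbol. Recent Polars
    # versions name this parameter `group_by`; older releases used `by`.
    df_range_index = (
        data.group_by_dynamic(index_column="date", every=range_str, group_by="symbol")
        .agg()
        .with_columns(
            # Number the windows 0, 1, 2, ... separately for each symbol.
            pl.int_range(0, pl.len()).over("symbol").alias("range_index")
        )
    )
    # Backward as-of join: each row takes the latest window start <= its date.
    data = data.join_asof(df_range_index, on="date", by="symbol")

    return data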

Original Question:

Data Logic:

I have a time series that needs to be split into chunks.

Let's say it needs to be split into 3 chunks for this post. The data I am using is stock quote data with daily prices. If the time series spans 3 months and the 'split' range is 1 month, then there should be 3 chunks of data, each month labeled with an increasing integer. So there should be 3 sections in the time series, all in one data frame, with a column named range_index that starts at 0 and counts up to 2. For example, if the data covered January through March, each price quote should be labeled 0, 1, or 2: 0 for January, 1 for February, and 2 for March.

I would like this to be done for each symbol in the data frame. The start date of each symbol may differ, so the function should be robust to that and assign range_index values based on each symbol's own stock data.

What I've Done:

I have built a function using Polars logic that adds this column to the data frame, but I think there are possibly faster ways to do this. When I add a few symbols with a few years of data, it takes about 3 s to execute.

I would love any advice on how to speed up the function, or even a novel approach. I'm aware that row-based operations are slower in Polars than columnar ones. If there are any Polars nerds out there who see a glaring issue... please help!

def add_range_index(
    data: pl.LazyFrame | pl.DataFrame, range_str: str
) -> pl.LazyFrame | pl.DataFrame:
    """
    Context: Toolbox || Category: Helpers || Sub-Category: Mandelbrot Channel Helpers || **Command: add_n_range**.

    This function is used to add a column to the dataframe that contains the
    range grouping for the entire time series.


    This function is used in `log_mean()`
    """  # noqa: W505
    range_str = _range_format(range_str)

    if "date" in data.columns:
        group_by_args = {
            "every": range_str,
            "closed": "left",
            "include_boundaries": True,
        }
        if "symbol" in data.columns:
            group_by_args["group_by"] = "symbol"

        grouped_data = (
            data.lazy()
            .set_sorted("date")
            .group_by_dynamic("date", **group_by_args)
            .agg(
                pl.col("adj_close").count().alias("n_obs")
            )  # count the observations in each window via 'adj_close'
        )
    range_row = grouped_data.with_columns(
        pl.int_range(0, pl.len()).over("symbol").alias("range_index")
    )
    ## WIP:
    # Extract the number of ranges the time series has

    # Initialize a new column to store the range index
    data = data.with_columns(pl.lit(None).alias("range_index"))

    # Loop through each range and add the range index to the original dataframe
    for row in range_row.collect().to_dicts():
        symbol = row["symbol"]
        start_date = row["_lower_boundary"]
        end_date = row["_upper_boundary"]
        range_index = row["range_index"]

        # Apply the conditional logic to each group defined by the 'symbol' column
        data = data.with_columns(
            pl.when(
                (pl.col("date") >= start_date)
                & (pl.col("date") < end_date)
                & (pl.col("symbol") == symbol)
            )
            .then(range_index)
            .otherwise(pl.col("range_index"))
            .over("symbol")  # Apply the logic over each 'symbol' group
            .alias("range_index")
        )

    return data


def _range_format(range_str: str) -> str:
    """
    Context: Toolbox || Category: Technical || Sub-Category: Mandelbrot Channel Helpers || **Command: _range_format**.

    This function formats a range string into a standard format.
    The return value is to be passed to `_range_days()`.

    Parameters
    ----------
    range_str : str
        The range string to format. It should contain a number followed by a
        range part. The range part can be 'day', 'week', 'month', 'quarter', or
        'year'. The range part can be in singular or plural form and can be
        abbreviated. For example, '2 weeks', '2week', '2wks', '2wk', '2w' are
        all valid.

    Returns
    -------
    str
        The formatted range string. The number is followed by an abbreviation of
        the range part ('d' for day, 'w' for week, 'mo' for month, 'q' for
        quarter, 'y' for year). For example, '2 weeks' is formatted as '2w'.

    Raises
    ------
    HumblDataError
        If an invalid range part is provided.

    Notes
    -----
    This function is used in `log_mean()`
    """  # noqa: W505
    # Separate the number and range part
    num = "".join(filter(str.isdigit, range_str))

    # Find the first alphabetic character, i.e. the start of the range part
    range_part = next((char for char in range_str if char.isalpha()), None)

    # Check if the range part is a valid abbreviation
    if range_part not in {"d", "w", "m", "y", "q"}:
        msg = f"`{range_str}` could not be formatted; needs to include d, w, m, y, q"
        raise HumblDataError(msg)

    # If the range part is "m", replace it with "mo" to represent "month"
    if range_part == "m":
        range_part = "mo"

    # Return the formatted range string
    return num + range_part

Expected Data Form:

[Images: Expected Output Data; Expected Data Output (cont.)]

The same labeling applies to the PCT stock symbol.


Solution

  • Another solution is to create a separate DataFrame that, for each symbol and range index, stores the corresponding start date.

    df_range_index = (
        df
        # Recent Polars versions name this parameter `group_by` (older: `by`).
        .group_by_dynamic(index_column="date", every="1w", group_by="symbol").agg()
        .with_columns(pl.int_range(0, pl.len()).over("symbol").alias("range_index"))
    )
    
    shape: (106, 3)
    ┌────────┬────────────┬─────────────┐
    │ symbol ┆ date       ┆ range_index │
    │ ---    ┆ ---        ┆ ---         │
    │ str    ┆ date       ┆ i64         │
    ╞════════╪════════════╪═════════════╡
    │ AAPL   ┆ 2022-12-26 ┆ 0           │
    │ AAPL   ┆ 2023-01-02 ┆ 1           │
    │ …      ┆ …          ┆ …           │
    │ PCT    ┆ 2023-12-11 ┆ 50          │
    │ PCT    ┆ 2023-12-18 ┆ 51          │
    │ PCT    ┆ 2023-12-25 ┆ 52          │
    └────────┴────────────┴─────────────┘
    

    We can then use pl.DataFrame.join_asof to merge the range index into the original dataframe.

    df.join_asof(df_range_index, on="date", by="symbol")
    

    Edit: As suggested by @jqurious, it might be possible to represent the range index as a simple truncation (+ some offset) of the date. Then, we can use .dt.truncate and map the date groups to ids using .rle_id.

    (
        df
        .with_columns(
            range_index=pl.col("date").dt.truncate(every="1w").rle_id().over("symbol")
        )
    )