python, pandas, numpy, numba, python-polars

Polars group_by map UDF using multiple columns as parameters


I have a numba UDF:

from typing import Union

import numba
import numpy as np


@numba.jit(nopython=True)
def generate_sample_numba(cumulative_dollar_volume: np.ndarray, dollar_tau: Union[int, np.ndarray]) -> np.ndarray:
    """Generate the sample using numba for speed."""
    covered_dollar_volume = 0
    bar_index = 0
    bar_index_array = np.zeros_like(cumulative_dollar_volume, dtype=np.uint32)

    # Broadcast a scalar threshold to one value per row (np.full is supported
    # in nopython mode; the isinstance check needs a reasonably recent numba).
    if isinstance(dollar_tau, int):
        dollar_tau = np.full(len(cumulative_dollar_volume), dollar_tau)

    # Start a new bar once the cumulative dollar volume has advanced by at
    # least the row's threshold since the last bar boundary.
    for i in range(len(cumulative_dollar_volume)):
        bar_index_array[i] = bar_index
        if cumulative_dollar_volume[i] >= covered_dollar_volume + dollar_tau[i]:
            bar_index += 1
            covered_dollar_volume = cumulative_dollar_volume[i]
    return bar_index_array

The UDF takes two inputs:

  1. The cumulative_dollar_volume numpy array, which holds one group's values from group_by
  2. The dollar_tau threshold, which is either an integer or numpy array.

In this question, I am particularly interested in the numpy array configuration. This post explains the idea behind the generate_sample_numba function well.
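
To make the behavior concrete, here is a minimal sketch with made-up numbers: each row is assigned the index of the bar it falls into, and a new bar starts once the cumulative dollar volume has grown by at least dollar_tau since the previous bar closed.

import numpy as np

cum = np.array([5.0, 12.0, 18.0, 25.0, 33.0])  # cumulative dollar volume
tau = np.full(len(cum), 10.0)                  # per-row threshold

print(generate_sample_numba(cum, tau))
# [0 0 1 1 2] -- bars close at rows 1 and 3, where the cumulative volume
# has advanced by at least 10 since the last bar boundary.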

I want to reproduce with Polars the result I get from Pandas:

data["bar_index"] = data.groupby(["ticker", "date"]).apply(lambda x: generate_sample_numba(x["cumulative_dollar_volume"].values, x["dollar_tau"].values)).explode().values.astype(int)

Apparently, the best option in Polars is group_by().agg(pl.col().map_batches()):

cqt_sample = cqt_sample.with_columns(
    (pl.col("price") * pl.col("size")).alias("dollar_volume")
).with_columns(
    pl.col("dollar_volume").cum_sum().over(["ticker", "date"]).alias("cumulative_dollar_volume"),
    pl.lit(1_000_000).alias("dollar_tau"),
)

(cqt_sample
    .group_by(["ticker", "date"])
    .agg(
        pl.col(["cumulative_dollar_volume", "dollar_tau"])
        .map_batches(lambda x: generate_sample_numba(x["cumulative_dollar_volume"].to_numpy(), 1_000_000))
        # .alias("bar_index")
    )
)  # .explode("bar_index")

but the map_batches() method seems to produce strange results.

However, when I use the integer dollar_tau with a single input column, it works fine:

(cqt_sample
    .group_by(["ticker", "date"])
    .agg(
        pl.col("cumulative_dollar_volume")
        .map_batches(lambda x: generate_sample_numba(x.to_numpy(), 1_000_000))
        .alias("bar_index")
    )
).explode("bar_index")

Solution

  • As suggested in the comments, you'll need to call pl.Expr.map_batches on a struct column that contains all the information the function needs: pl.col([...]) expands into one expression per column, so map_batches is called once per column with a single Series rather than with both columns at once. Inside the function, you then pick the struct apart to obtain the desired fields.

    (
        data
        .group_by(["ticker", "date"])
        .agg(
            pl.struct("cumulative_dollar_volume", "dollar_tau")
            .map_batches(lambda x: generate_sample_numba(
                x.struct.field("cumulative_dollar_volume").to_numpy(),
                dollar_tau=x.struct.field("dollar_tau").to_numpy(),
            ))
            .alias("bar_index")
        )
    ).explode("bar_index")
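
  • A possible follow-up (a sketch, assuming map_batches is evaluated per group inside a window context, which may vary by Polars version): agg plus explode returns the bar indices in group order rather than the original row order. If you want bar_index aligned with the input rows, the same struct expression can be evaluated as a window via .over:

    data = data.with_columns(
        pl.struct("cumulative_dollar_volume", "dollar_tau")
        .map_batches(lambda x: generate_sample_numba(
            x.struct.field("cumulative_dollar_volume").to_numpy(),
            dollar_tau=x.struct.field("dollar_tau").to_numpy(),
        ))
        .over(["ticker", "date"])
        .alias("bar_index")
    )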