I have a numba UDF:
```python
from typing import Union

import numba
import numpy as np


@numba.jit(nopython=True)
def generate_sample_numba(cumulative_dollar_volume: np.ndarray, dollar_tau: Union[int, np.ndarray]) -> np.ndarray:
    """ Generate the sample using numba for speed.
    """
    covered_dollar_volume = 0
    bar_index = 0
    bar_index_array = np.zeros_like(cumulative_dollar_volume, dtype=np.uint32)
    # A scalar threshold is expanded to one value per row
    if isinstance(dollar_tau, int):
        dollar_tau = np.array([dollar_tau] * len(cumulative_dollar_volume))
    for i in range(len(cumulative_dollar_volume)):
        bar_index_array[i] = bar_index
        # Start a new bar once the volume since the last cut reaches the threshold
        if cumulative_dollar_volume[i] >= covered_dollar_volume + dollar_tau[i]:
            bar_index += 1
            covered_dollar_volume = cumulative_dollar_volume[i]
    return bar_index_array
```
The UDF takes two inputs:

- `cumulative_dollar_volume`: a NumPy array; in practice, the values of one group from the `group_by`.
- `dollar_tau`: the threshold, which is either an integer or a NumPy array.

In this question, I am particularly interested in the NumPy array configuration. This post explains the idea behind the `generate_sample_numba` function well.
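For reference, the loop can be checked without numba. The sketch below is a plain NumPy/Python copy of the same logic. The name `generate_sample_reference` is hypothetical and does not come from the original post. `np.broadcast_to` replaces the `isinstance` branch, so a scalar threshold and an array threshold follow the same code path. The toy inputs are made up.

```python
import numpy as np


def generate_sample_reference(cumulative_dollar_volume, dollar_tau):
    """Hypothetical numba-free version of generate_sample_numba, for illustration only."""
    cumulative_dollar_volume = np.asarray(cumulative_dollar_volume)
    # A scalar threshold is broadcast to one value per row; an array is used as-is.
    dollar_tau = np.broadcast_to(np.asarray(dollar_tau), cumulative_dollar_volume.shape)
    bar_index_array = np.zeros(len(cumulative_dollar_volume), dtype=np.uint32)
    covered_dollar_volume = 0
    bar_index = 0
    for i in range(len(cumulative_dollar_volume)):
        # The current row belongs to the bar that is still open...
        bar_index_array[i] = bar_index
        # ...and once the volume since the last cut reaches the threshold,
        # the following rows get the next bar index.
        if cumulative_dollar_volume[i] >= covered_dollar_volume + dollar_tau[i]:
            bar_index += 1
            covered_dollar_volume = cumulative_dollar_volume[i]
    return bar_index_array


cum = np.array([3, 6, 10, 12, 20])
print(generate_sample_reference(cum, 5))                            # [0 0 1 1 2]
print(generate_sample_reference(cum, np.array([5, 5, 2, 10, 10])))  # [0 0 1 2 2]
```

The row whose cumulative volume reaches the threshold still carries the current bar index. The index only increments from the next row on.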
I want to reproduce the following pandas result with Polars:

```python
data["bar_index"] = (
    data.groupby(["ticker", "date"])
    .apply(lambda x: generate_sample_numba(x["cumulative_dollar_volume"].values, x["dollar_tau"].values))
    .explode()
    .values.astype(int)
)
```
Apparently, the best option in Polars is `group_by().agg(pl.col().map_batches())`:
```python
cqt_sample = cqt_sample.with_columns(
    (pl.col("price") * pl.col("size")).alias("dollar_volume")
).with_columns(
    pl.col("dollar_volume").cum_sum().over(["ticker", "date"]).alias("cumulative_dollar_volume"),
    pl.lit(1_000_000).alias("dollar_tau"),
)

(cqt_sample
    .group_by(["ticker", "date"])
    .agg(pl.col(["cumulative_dollar_volume", "dollar_tau"])
        .map_batches(lambda x: generate_sample_numba(x["cumulative_dollar_volume"].to_numpy(), 1_000_000))
    )#.alias("bar_index")
)#.explode("bar_index")
```
but the `map_batches()` method returns some strange results.
However, when I use an integer `dollar_tau` with a single input column, it works fine:

```python
(cqt_sample
    .group_by(["ticker", "date"])
    .agg(
        pl.col("cumulative_dollar_volume")
        .map_batches(lambda x: generate_sample_numba(x.to_numpy(), 1_000_000))
        .alias("bar_index")
    )
).explode("bar_index")
```
As suggested in the comments, you'll need to call `pl.Expr.map_batches` on a struct column that contains all the information needed by the function. Inside the function, you then pick the struct apart with `.struct.field(...)` to obtain the individual columns.
```python
(
    data
    .group_by(["ticker", "date"])
    .agg(
        pl.struct("cumulative_dollar_volume", "dollar_tau").map_batches(
            lambda x: generate_sample_numba(
                x.struct.field("cumulative_dollar_volume").to_numpy(),
                dollar_tau=x.struct.field("dollar_tau").to_numpy(),
            )
        )
        .alias("bar_index")
    )
).explode("bar_index")
```