I have a numba UDF:
```python
from typing import Union

import numba
import numpy as np


@numba.jit(nopython=True)
def generate_sample_numba(cumulative_dollar_volume: np.ndarray, dollar_tau: Union[int, np.ndarray]) -> np.ndarray:
    """Generate the sample using numba for speed."""
    covered_dollar_volume = 0
    bar_index = 0
    bar_index_array = np.zeros_like(cumulative_dollar_volume, dtype=np.uint32)
    # Broadcast an integer threshold to one value per row
    if isinstance(dollar_tau, int):
        dollar_tau = np.array([dollar_tau] * len(cumulative_dollar_volume))
    for i in range(len(cumulative_dollar_volume)):
        # Each row is labelled with the bar it belongs to
        bar_index_array[i] = bar_index
        # Close the bar once the volume since the last bar reaches the threshold
        if cumulative_dollar_volume[i] >= covered_dollar_volume + dollar_tau[i]:
            bar_index += 1
            covered_dollar_volume = cumulative_dollar_volume[i]
    return bar_index_array
```
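To trace what the loop does without Numba, here is a hypothetical un-jitted reference (`generate_sample_reference` is not part of the original code) that broadcasts `dollar_tau` to one threshold per row. The worked example shows that the row which crosses the threshold still receives the old bar index; the next row starts the new bar.

```python
import numpy as np


def generate_sample_reference(cumulative_dollar_volume, dollar_tau):
    # Hypothetical un-jitted reference: same loop as generate_sample_numba
    cdv = np.asarray(cumulative_dollar_volume, dtype=np.float64)
    # Broadcast an int threshold (or check an array one) to one value per row
    tau = np.broadcast_to(np.asarray(dollar_tau, dtype=np.float64), cdv.shape)
    bar_index_array = np.zeros(cdv.shape, dtype=np.uint32)
    covered = 0.0
    bar_index = 0
    for i in range(cdv.shape[0]):
        bar_index_array[i] = bar_index  # row i belongs to the current bar
        if cdv[i] >= covered + tau[i]:  # threshold reached: this row closes the bar
            bar_index += 1
            covered = cdv[i]
    return bar_index_array


cdv = np.array([300, 700, 1200, 1500, 2600])
print(generate_sample_reference(cdv, 1000).tolist())  # -> [0, 0, 0, 1, 1]
print(generate_sample_reference(cdv, np.array([1000, 1000, 1000, 300, 300])).tolist())  # -> [0, 0, 0, 1, 2]
```

With a per-row threshold, a smaller `dollar_tau` later in the group closes bars sooner, which is why the second call ends on bar 2.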
The UDF takes two inputs:
- `cumulative_dollar_volume`: a NumPy array holding the cumulative dollar volume of one group of the `group_by`;
- `dollar_tau`: the threshold, which is either an integer or a NumPy array.
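Because `dollar_tau` can be either an integer or an array, one option (an assumption, not part of the original code) is to normalise it to a float64 array before calling the jitted function, so the `isinstance` branch inside the Numba code becomes unnecessary. `as_threshold_array` is a hypothetical helper name:

```python
import numpy as np


def as_threshold_array(dollar_tau, n):
    # Hypothetical helper: turn an int-or-array threshold into a float64 array of length n
    if np.isscalar(dollar_tau):
        return np.full(n, dollar_tau, dtype=np.float64)
    tau = np.asarray(dollar_tau, dtype=np.float64)
    if tau.shape != (n,):
        raise ValueError(f"dollar_tau has shape {tau.shape}, expected ({n},)")
    return tau


print(as_threshold_array(1_000_000, 3))
print(as_threshold_array([1, 2], 2))
```

Numba compiles a separate specialisation for each distinct combination of argument types, so always passing a float64 array means only one version of the jitted function is compiled and used.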
In this question, I am particularly interested in the NumPy array configuration. This post explains the idea behind the `generate_sample_numba` function well.
I want to reproduce the result of this pandas code using Polars:
```python
data["bar_index"] = (
    data.groupby(["ticker", "date"])
    .apply(lambda x: generate_sample_numba(x["cumulative_dollar_volume"].values, x["dollar_tau"].values))
    .explode()
    .values.astype(int)
)
```
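Since the bar logic only needs each group's rows in order, one alternative (not what the question or the comments propose) is a single loop over the whole column that resets its state whenever the group changes, avoiding one Python call per group. The hypothetical `generate_sample_grouped` below assumes the rows are sorted so that each (`ticker`, `date`) group is contiguous and that `group_id` is an integer label per row. It is plain NumPy, but it only uses loops, scalar arithmetic and array indexing, so it could presumably be decorated with `@numba.njit` as well.

```python
import numpy as np


def generate_sample_grouped(group_id, cumulative_dollar_volume, dollar_tau):
    # Same bar logic as generate_sample_numba, run once over all rows
    n = len(cumulative_dollar_volume)
    out = np.zeros(n, dtype=np.uint32)
    covered = 0.0
    bar_index = 0
    for i in range(n):
        if i > 0 and group_id[i] != group_id[i - 1]:
            # A new (ticker, date) group starts: reset the running state
            covered = 0.0
            bar_index = 0
        out[i] = bar_index
        if cumulative_dollar_volume[i] >= covered + dollar_tau[i]:
            bar_index += 1
            covered = cumulative_dollar_volume[i]
    return out


group_id = np.array([0, 0, 0, 1, 1, 1])
cdv = np.array([300.0, 700.0, 1200.0, 500.0, 1000.0, 1600.0])
tau = np.full(6, 1000.0)
print(generate_sample_grouped(group_id, cdv, tau).tolist())  # -> [0, 0, 0, 0, 0, 1]
```

With pandas, a contiguous group label could come from `data.groupby(["ticker", "date"], sort=False).ngroup()`, provided the frame is already ordered by those keys.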
Apparently, the best option in Polars is `group_by().agg(pl.col().map_batches())`:
```python
cqt_sample = cqt_sample.with_columns(
    (pl.col("price") * pl.col("size")).alias("dollar_volume")
).with_columns(
    pl.col("dollar_volume").cum_sum().over(["ticker", "date"]).alias("cumulative_dollar_volume"),
    pl.lit(1_000_000).alias("dollar_tau"),
)

(
    cqt_sample
    .group_by(["ticker", "date"])
    .agg(
        pl.col(["cumulative_dollar_volume", "dollar_tau"])
        .map_batches(lambda x: generate_sample_numba(x["cumulative_dollar_volume"].to_numpy(), 1_000_000))
    )  # .alias("bar_index")
)  # .explode("bar_index")
```
but `map_batches()` seems to produce some strange results.
However, when I use the integer `dollar_tau` with a single input column, it works fine:
```python
(
    cqt_sample
    .group_by(["ticker", "date"])
    .agg(
        pl.col("cumulative_dollar_volume")
        .map_batches(lambda x: generate_sample_numba(x.to_numpy(), 1_000_000))
        .alias("bar_index")  # alias the expression inside agg(); a DataFrame has no .alias()
    )
    .explode("bar_index")
)
```
`cqt_sample` in order to make your example runnable, along with the expected output?

`pl.col("a", "b").map_batches` is shorthand for `pl.col("a").map_batches, pl.col("b").map_batches`; but you want to pass multiple items at the same time, so you would likely need a struct: `pl.struct("a", "b").map_batches`.