This is a new question/issue as a follow-up to How to return multiple stats as multiple columns in Polars grouby context? and How to flatten/split a tuple of arrays and calculate column means in Polars dataframe?
Basically, the problem can be easily illustrated by the example below:
from functools import partial
import polars as pl
import statsmodels.api as sm
def ols_stats(s, yvar, xvars):
    df = s.struct.unnest()
    reg = sm.OLS(df[yvar].to_numpy(), df[xvars].to_numpy(), missing="drop").fit()
    return pl.Series(values=(reg.params, reg.tvalues), nan_to_null=True)
df = pl.DataFrame(
    {
        "day": [1, 1, 1, 1, 1, 2, 2, 2, 2, 2],
        "y": [1, 6, 3, 2, 8, 4, 5, 2, 7, 3],
        "x1": [1, 8, 2, 3, 5, 2, 1, 2, 7, 3],
        "x2": [8, 5, 3, 6, 3, 7, 3, 2, 9, 1],
    }
).lazy()
res = df.group_by("day").agg(
    pl.struct("y", "x1", "x2")
    .map_elements(partial(ols_stats, yvar="y", xvars=["x1", "x2"]))
    .alias("params")
)
res.with_columns(
    pl.col("params").list.eval(pl.element().list.explode()).list.to_struct()
).unnest("params").collect()
Running the code above raises the following error:
PanicException: expected known type
But when .lazy() and .collect() are removed from the code above, the code works perfectly as intended. Below is the result (expected behavior) when running in eager mode.
shape: (2, 5)
┌─────┬──────────┬──────────┬──────────┬───────────┐
│ day ┆ field_0 ┆ field_1 ┆ field_2 ┆ field_3 │
│ --- ┆ --- ┆ --- ┆ --- ┆ --- │
│ i64 ┆ f64 ┆ f64 ┆ f64 ┆ f64 │
╞═════╪══════════╪══════════╪══════════╪═══════════╡
│ 2 ┆ 0.466089 ┆ 0.503127 ┆ 0.916982 ┆ 1.451151 │
│ 1 ┆ 1.008659 ┆ -0.03324 ┆ 3.204266 ┆ -0.124422 │
└─────┴──────────┴──────────┴──────────┴───────────┘
So, what is the problem and how am I supposed to resolve it?
Don't return a Series from ols_stats() but a dict; then it should work. This is also semantically better, as the struct you show at the end is a mess: the first two fields are params and the second two are tvalues. Try this instead:
def ols_stats(s, yvar, xvars):
    df = s.struct.unnest()
    reg = sm.OLS(df[yvar].to_numpy(), df[xvars].to_numpy(), missing="drop").fit()
    return {"params": reg.params.tolist(), "tvalues": reg.tvalues.tolist()}
Polars automatically turns the dict[list[f64]] into a struct[2]. I had to play around a bit to figure this out, but it seems to work.
This way you end up with semantically meaningful results:
shape: (3, 3)
┌─────┬─────────────────────────────────┬────────────────────────────────┐
│ day ┆ params ┆ tvalues │
│ --- ┆ --- ┆ --- │
│ i64 ┆ list[f64] ┆ list[f64] │
╞═════╪═════════════════════════════════╪════════════════════════════════╡
│ 1 ┆ [4.866232, 0.640294, -0.659869] ┆ [1.547251, 1.81586, -1.430613] │
│ 3 ┆ [0.5, 0.5] ┆ [0.0, 0.0] │
│ 2 ┆ [2.0462, 0.223971, 0.336793] ┆ [1.524834, 0.495378, 1.091109] │
└─────┴─────────────────────────────────┴────────────────────────────────┘
Now it works lazily:
df.group_by("day").agg(
    pl.struct("y", "x1", "x2")
    .map_elements(partial(ols_stats, yvar="y", xvars=["x1", "x2"]))
    .alias("params")
).unnest("params").collect()
If you want things unnested, why not return them already unnested, like this:
def ols_stats(s, yvar, xvars):
    df = s.struct.unnest()
    reg = sm.OLS(df[yvar].to_numpy(), df[xvars].to_numpy(), missing="drop").fit()
    param_dict = {f"param_{i}": v for i, v in enumerate(reg.params.tolist())}
    tvalues_dict = {f"tvalue_{i}": v for i, v in enumerate(reg.tvalues.tolist())}
    return param_dict | tvalues_dict
df = pl.LazyFrame(
    {
        "day": [1, 1, 1, 1, 1, 2, 2, 2, 2, 2],
        "y": [1, 6, 3, 2, 8, 4, 5, 2, 7, 3],
        "x1": [1, 8, 2, 3, 5, 2, 1, 2, 7, 3],
        "x2": [8, 5, 3, 6, 3, 7, 3, 2, 9, 1],
    }
)
res = df.group_by("day").agg(
    pl.struct("y", "x1", "x2")
    .map_elements(partial(ols_stats, yvar="y", xvars=["x1", "x2"]))
    .alias("results")
).unnest("results").collect()
print(res)
Returns:
shape: (2, 5)
┌─────┬──────────┬──────────┬──────────┬───────────┐
│ day ┆ param_0 ┆ param_1 ┆ tvalue_0 ┆ tvalue_1 │
│ --- ┆ --- ┆ --- ┆ --- ┆ --- │
│ i64 ┆ f64 ┆ f64 ┆ f64 ┆ f64 │
╞═════╪══════════╪══════════╪══════════╪═══════════╡
│ 1 ┆ 1.008659 ┆ -0.03324 ┆ 3.204266 ┆ -0.124422 │
│ 2 ┆ 0.466089 ┆ 0.503127 ┆ 0.916982 ┆ 1.451151 │
└─────┴──────────┴──────────┴──────────┴───────────┘