As a simple example, consider the following, using group_by:
import polars as pl
df = pl.DataFrame(
[pl.Series("id", ["a", "b", "a"]), pl.Series("x", [0, 1, 2])]
)
print(df.group_by("id").agg(pl.col("x")))
# shape: (2, 2)
# ┌─────┬───────────┐
# │ id ┆ x │
# │ --- ┆ --- │
# │ str ┆ list[i64] │
# ╞═════╪═══════════╡
# │ b ┆ [1] │
# │ a ┆ [0, 2] │
# └─────┴───────────┘
But if we use over, we get:
import polars as pl
df = pl.DataFrame(
[pl.Series("id", ["a", "b", "a"]), pl.Series("x", [0, 1, 2])]
)
print(df.with_columns(pl.col("x").over("id")))
# shape: (3, 2)
# ┌─────┬─────┐
# │ id ┆ x │
# │ --- ┆ --- │
# │ str ┆ i64 │
# ╞═════╪═════╡
# │ a ┆ 0 │
# │ b ┆ 1 │
# │ a ┆ 2 │
# └─────┴─────┘
How can the group_by result be achieved using over? By passing mapping_strategy="join".
A slightly more complicated example, meant to showcase why we might want to use over instead of group_by:
import polars as pl
# the smallest normal value a Float32 can encode is about 1.18e-38
# therefore, as far as we are concerned,
# 1e-41 and 1e-42 should be indistinguishable
# in other words, we do not want to use "other" as an id column
# but we do want to preserve other!
df = pl.DataFrame(
[
pl.Series("id", ["a", "b", "a"]),
pl.Series("other", [1e-41, 1e-16, 1e-42], dtype=pl.Float32()),
pl.Series("x", [0, 1, 2]),
]
)
print(df.group_by("id").agg(pl.col("x"), pl.col("other")).explode("other"))
# shape: (3, 3)
# ┌─────┬───────────┬────────────┐
# │ id ┆ x ┆ other │
# │ --- ┆ --- ┆ --- │
# │ str ┆ list[i64] ┆ f32 │
# ╞═════╪═══════════╪════════════╡
# │ a ┆ [0, 2] ┆ 9.9997e-42 │
# │ a ┆ [0, 2] ┆ 1.0005e-42 │
# │ b ┆ [1] ┆ 1.0000e-16 │
# └─────┴───────────┴────────────┘
Now, using over:
import polars as pl
# the smallest normal value a Float32 can encode is about 1.18e-38
# therefore, as far as we are concerned,
# 1e-41 and 1e-42 should be indistinguishable
# in other words, we do not want to use "other" as an id column
# but we do want to preserve other!
df = pl.DataFrame(
[
pl.Series("id", ["a", "b", "a"]),
pl.Series("other", [1e-41, 1e-16, 1e-42], dtype=pl.Float32()),
pl.Series("x", [0, 1, 2]),
]
)
print(df.with_columns(pl.col("x").over(["id"], mapping_strategy="join")))
# shape: (3, 3)
# ┌─────┬────────────┬───────────┐
# │ id ┆ other ┆ x │
# │ --- ┆ --- ┆ --- │
# │ str ┆ f32 ┆ list[i64] │
# ╞═════╪════════════╪═══════════╡
# │ a ┆ 9.9997e-42 ┆ [0, 2] │
# │ b ┆ 1.0000e-16 ┆ [1] │
# │ a ┆ 1.0005e-42 ┆ [0, 2] │
# └─────┴────────────┴───────────┘
The trouble with mapping_strategy="join" is that it's very slow. This suggests that I ought to do a group_by followed by a join:
import polars as pl
import polars.selectors as cs
# the smallest normal value a Float32 can encode is about 1.18e-38
# therefore, as far as we are concerned,
# 1e-41 and 1e-42 should be indistinguishable
# in other words, we do not want to use "other" as an id column
# but we do want to preserve other!
df = pl.DataFrame(
[
pl.Series("id", ["a", "b", "a"]),
pl.Series("other", [1e-41, 1e-16, 1e-42], dtype=pl.Float32()),
pl.Series("x", [0, 1, 2]),
]
)
print(
df.select(cs.exclude("x")).join(
df.group_by("id").agg("x"),
on="id",
# we expect there to be multiple "id"s on the left, matching
# a single "id" on the right
validate="m:1",
)
)
# shape: (3, 3)
# ┌─────┬────────────┬───────────┐
# │ id ┆ other ┆ x │
# │ --- ┆ --- ┆ --- │
# │ str ┆ f32 ┆ list[i64] │
# ╞═════╪════════════╪═══════════╡
# │ a ┆ 9.9997e-42 ┆ [0, 2] │
# │ b ┆ 1.0000e-16 ┆ [1] │
# │ a ┆ 1.0005e-42 ┆ [0, 2] │
# └─────┴────────────┴───────────┘
But perhaps I am missing something else about over?
Currently, Polars does not distinguish between Scalars and Series. Any Series with length 1 is considered a scalar and will be broadcast when combined with other Series.
We are actively working on distinguishing these two concepts better, and once we have done so I would expect list aggregations like Expr.implode() in over contexts to work as you'd expect. That is, the following should solve your problem, but currently doesn't:
>>> df.with_columns(pl.col.x.implode().over("id"))
To do it currently I would suggest a normal aggregation + join:
>>> df.drop("x").join(df.group_by("id").agg("x"), on="id")
shape: (3, 3)
┌─────┬────────────┬───────────┐
│ id ┆ other ┆ x │
│ --- ┆ --- ┆ --- │
│ str ┆ f32 ┆ list[i64] │
╞═════╪════════════╪═══════════╡
│ a ┆ 9.9997e-42 ┆ [0, 2] │
│ b ┆ 1.0000e-16 ┆ [1] │
│ a ┆ 1.0005e-42 ┆ [0, 2] │
└─────┴────────────┴───────────┘