
Using Polars, how do I efficiently do an `over` that collects items into a list?


As a simple example, consider the following, using `group_by`:

import polars as pl

df = pl.DataFrame(
    [pl.Series("id", ["a", "b", "a"]), pl.Series("x", [0, 1, 2])]
)
print(df.group_by("id").agg(pl.col("x")))

# shape: (2, 2)
# ┌─────┬───────────┐
# │ id  ┆ x         │
# │ --- ┆ ---       │
# │ str ┆ list[i64] │
# ╞═════╪═══════════╡
# │ b   ┆ [1]       │
# │ a   ┆ [0, 2]    │
# └─────┴───────────┘

But if we use over, we get:

import polars as pl

df = pl.DataFrame(
    [pl.Series("id", ["a", "b", "a"]), pl.Series("x", [0, 1, 2])]
)
print(df.with_columns(pl.col("x").over("id")))
# shape: (3, 2)
# ┌─────┬─────┐
# │ id  ┆ x   │
# │ --- ┆ --- │
# │ str ┆ i64 │
# ╞═════╪═════╡
# │ a   ┆ 0   │
# │ b   ┆ 1   │
# │ a   ┆ 2   │
# └─────┴─────┘

How can the `group_by` result be achieved using `over`? By passing `mapping_strategy="join"`.

A slightly more complicated example, meant to show why we might want to use `over` instead of `group_by`:

import polars as pl

# the smallest value a Float32 can encode is 1e-38
# therefore, as far as we are concerned,
# 1e-41 and 1e-42 should be indistinguishable
# in other words, we do not want to use "other" as an id column
# but we do want to preserve other!
df = pl.DataFrame(
    [
        pl.Series("id", ["a", "b", "a"]),
        pl.Series("other", [1e-41, 1e-16, 1e-42], dtype=pl.Float32()),
        pl.Series("x", [0, 1, 2]),
    ]
)
print(df.group_by("id").agg(pl.col("x"), pl.col("other")).explode("other"))

# shape: (3, 3)
# ┌─────┬───────────┬────────────┐
# │ id  ┆ x         ┆ other      │
# │ --- ┆ ---       ┆ ---        │
# │ str ┆ list[i64] ┆ f32        │
# ╞═════╪═══════════╪════════════╡
# │ a   ┆ [0, 2]    ┆ 9.9997e-42 │
# │ a   ┆ [0, 2]    ┆ 1.0005e-42 │
# │ b   ┆ [1]       ┆ 1.0000e-16 │
# └─────┴───────────┴────────────┘

Now, using over:

import polars as pl

# the smallest value a Float32 can encode is 1e-38
# therefore, as far as we are concerned,
# 1e-41 and 1e-42 should be indistinguishable
# in other words, we do not want to use "other" as an id column
# but we do want to preserve other!
df = pl.DataFrame(
    [
        pl.Series("id", ["a", "b", "a"]),
        pl.Series("other", [1e-41, 1e-16, 1e-42], dtype=pl.Float32()),
        pl.Series("x", [0, 1, 2]),
    ]
)
print(df.with_columns(pl.col("x").over(["id"], mapping_strategy="join")))

# shape: (3, 3)
# ┌─────┬────────────┬───────────┐
# │ id  ┆ other      ┆ x         │
# │ --- ┆ ---        ┆ ---       │
# │ str ┆ f32        ┆ list[i64] │
# ╞═════╪════════════╪═══════════╡
# │ a   ┆ 9.9997e-42 ┆ [0, 2]    │
# │ b   ┆ 1.0000e-16 ┆ [1]       │
# │ a   ┆ 1.0005e-42 ┆ [0, 2]    │
# └─────┴────────────┴───────────┘

The trouble with `mapping_strategy="join"` is that it's very slow. So this suggests that I ought to do a `group_by` followed by a join:

import polars as pl
import polars.selectors as cs

# the smallest value a Float32 can encode is 1e-38
# therefore, as far as we are concerned,
# 1e-41 and 1e-42 should be indistinguishable
# in other words, we do not want to use "other" as an id column
# but we do want to preserve other!
df = pl.DataFrame(
    [
        pl.Series("id", ["a", "b", "a"]),
        pl.Series("other", [1e-41, 1e-16, 1e-42], dtype=pl.Float32()),
        pl.Series("x", [0, 1, 2]),
    ]
)


print(
    df.select(cs.exclude("x")).join(
        df.group_by("id").agg("x"),
        on="id",
        # we expect there to be multiple "id"s on the left, matching
        # a single "id" on the right
        validate="m:1",
    )
)

# shape: (3, 3)
# ┌─────┬────────────┬───────────┐
# │ id  ┆ other      ┆ x         │
# │ --- ┆ ---        ┆ ---       │
# │ str ┆ f32        ┆ list[i64] │
# ╞═════╪════════════╪═══════════╡
# │ a   ┆ 9.9997e-42 ┆ [0, 2]    │
# │ b   ┆ 1.0000e-16 ┆ [1]       │
# │ a   ┆ 1.0005e-42 ┆ [0, 2]    │
# └─────┴────────────┴───────────┘

But perhaps I am missing something else about over?


Solution

  • Currently, Polars does not distinguish between Scalars and Series. Any Series with length 1 is considered a scalar and will be broadcast when combined with other Series.

    We are actively working on distinguishing these two concepts better, and once we have done so I would expect list aggregations like Expr.implode() in over contexts to work as you'd expect. That is, the following should solve your problem, but currently doesn't:

    >>> df.with_columns(pl.col.x.implode().over("id"))
    

    To do it currently I would suggest a normal aggregation + join:

    >>> df.drop("x").join(df.group_by("id").agg("x"), on="id")
    shape: (3, 3)
    ┌─────┬────────────┬───────────┐
    │ id  ┆ other      ┆ x         │
    │ --- ┆ ---        ┆ ---       │
    │ str ┆ f32        ┆ list[i64] │
    ╞═════╪════════════╪═══════════╡
    │ a   ┆ 9.9997e-42 ┆ [0, 2]    │
    │ b   ┆ 1.0000e-16 ┆ [1]       │
    │ a   ┆ 1.0005e-42 ┆ [0, 2]    │
    └─────┴────────────┴───────────┘