
group_by result order is inconsistent in Polars


Based on the example from the Aggregation section of the Polars User Guide:

import polars as pl
from datetime import date

def compute_age() -> pl.Expr:
    return date(2021, 1, 1).year - pl.col("birthday").dt.year()

def avg_birthday(gender: str) -> pl.Expr:
    return compute_age().filter(
            pl.col("gender") == gender
        ).mean().alias(f"avg {gender} birthday")


dataset = pl.read_csv(b"""
state,gender,birthday
GA,M,1861-06-06
MD,F,1920-09-17
PA,M,1778-10-13
KS,M,1926-02-23
CO,M,1959-02-16
IL,F,1937-08-15
NY,M,1803-04-30
TX,F,1935-12-03
MD,M,1756-06-03
OH,M,1786-11-15
""".strip(), try_parse_dates=True).lazy()

q = (
    dataset
    .group_by("state")
    .agg(
        avg_birthday("M"), 
        avg_birthday("F"),
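        # Note: .count() counts every value in the group, True or False, so this
        # is the group size rather than the number of males (which is why TX shows
        # 1 in "# male" below); .sum() would count only the True values.
        (pl.col("gender") == "M").count().alias("# male"),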
        (pl.col("gender") == "F").sum().alias("# female"),
    )
)

The row order of the result changes from run to run. For example, the first time I run q.collect().head(), I get:

shape: (5, 5)
┌───────┬────────────────┬────────────────┬────────┬──────────┐
│ state ┆ avg M birthday ┆ avg F birthday ┆ # male ┆ # female │
│ ---   ┆ ---            ┆ ---            ┆ ---    ┆ ---      │
│ str   ┆ f64            ┆ f64            ┆ u32    ┆ u32      │
╞═══════╪════════════════╪════════════════╪════════╪══════════╡
│ GA    ┆ 160.0          ┆ null           ┆ 1      ┆ 0        │
│ OH    ┆ 235.0          ┆ null           ┆ 1      ┆ 0        │
│ CO    ┆ 62.0           ┆ null           ┆ 1      ┆ 0        │
│ KS    ┆ 95.0           ┆ null           ┆ 1      ┆ 0        │
│ TX    ┆ null           ┆ 86.0           ┆ 1      ┆ 1        │
└───────┴────────────────┴────────────────┴────────┴──────────┘

The second time:

shape: (5, 5)
┌───────┬────────────────┬────────────────┬────────┬──────────┐
│ state ┆ avg M birthday ┆ avg F birthday ┆ # male ┆ # female │
│ ---   ┆ ---            ┆ ---            ┆ ---    ┆ ---      │
│ str   ┆ f64            ┆ f64            ┆ u32    ┆ u32      │
╞═══════╪════════════════╪════════════════╪════════╪══════════╡
│ TX    ┆ null           ┆ 86.0           ┆ 1      ┆ 1        │
│ OH    ┆ 235.0          ┆ null           ┆ 1      ┆ 0        │
│ CO    ┆ 62.0           ┆ null           ┆ 1      ┆ 0        │
│ NY    ┆ 218.0          ┆ null           ┆ 1      ┆ 0        │
│ GA    ┆ 160.0          ┆ null           ┆ 1      ┆ 0        │
└───────┴────────────────┴────────────────┴────────┴──────────┘

I suspect this is caused by parallelization. Is this a bug or intended behavior? How can I keep the order of the result consistent?


Solution

  • Pass maintain_order=True to group_by.

    maintain_order: Ensure that the order of the groups is consistent with the input data. This is slower than a default group by. Setting this to True blocks the possibility to run on the streaming engine.