python, python-polars

How to filter on uniqueness by condition


Imagine I have a dataset like:

import polars as pl

data = {
    "a": [1, 4, 2, 4, 7, 4],
    "b": [4, 2, 3, 3, 0, 2],
    "c": ["a", "b", "c", "d", "e", "f"],
}

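and I want to keep only the rows whose sum a + b is produced by exactly one distinct combination of a and b. In other words, if two different (a, b) pairs give the same sum, every row with that sum should be dropped. I managed to hack this together: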

df = (
    pl.DataFrame(data)
    .with_columns(sum_ab=pl.col("a") + pl.col("b"))
    .group_by("sum_ab")
    .agg(pl.col("a"), pl.col("b"), pl.col("c"))
    .filter(
        (pl.col("a").list.unique().list.len() == 1)
        & (pl.col("b").list.unique().list.len() == 1)
    )
    .explode(["a", "b", "c"])
    .select("a", "b", "c")
)

"""
shape: (2, 3)
┌─────┬─────┬─────┐
│ a   ┆ b   ┆ c   │
│ --- ┆ --- ┆ --- │
│ i64 ┆ i64 ┆ str │
╞═════╪═════╪═════╡
│ 4   ┆ 2   ┆ b   │
│ 4   ┆ 2   ┆ f   │
└─────┴─────┴─────┘
"""

Can someone suggest a better way to achieve the same? I struggled a bit to figure this logic out, so I imagine there is a more direct/elegant way of getting the same result.


Solution

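    • pl.struct() to combine a and b into a single column, so that uniqueness is checked on the (a, b) pair rather than on each column separately.
    • n_unique() to count the distinct pairs.
    • over() to run that count within each group of rows sharing the same value of a + b.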
    df = pl.DataFrame(data)

    df.filter(
        pl.struct("a", "b").n_unique().over(pl.col.a + pl.col.b) == 1
    )
    
    ┌─────┬─────┬─────┐
    │ a   ┆ b   ┆ c   │
    │ --- ┆ --- ┆ --- │
    │ i64 ┆ i64 ┆ str │
    ╞═════╪═════╪═════╡
    │ 4   ┆ 2   ┆ b   │
    │ 4   ┆ 2   ┆ f   │
    └─────┴─────┴─────┘
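
To make the logic of the three steps above concrete, here is a minimal plain-Python sketch of what the window expression computes for the question's data. It is only an illustration of the idea, not how Polars implements it; the names pairs_per_sum and kept are made up for this example.

```python
from collections import defaultdict

data = {
    "a": [1, 4, 2, 4, 7, 4],
    "b": [4, 2, 3, 3, 0, 2],
    "c": ["a", "b", "c", "d", "e", "f"],
}

# Collect the distinct (a, b) pairs seen for each value of a + b.
# This plays the role of struct("a", "b").n_unique().over(a + b).
pairs_per_sum = defaultdict(set)
for a, b in zip(data["a"], data["b"]):
    pairs_per_sum[a + b].add((a, b))

# Keep a row only if its sum was produced by exactly one distinct pair.
kept = [
    (a, b, c)
    for a, b, c in zip(data["a"], data["b"], data["c"])
    if len(pairs_per_sum[a + b]) == 1
]

print(kept)
```

Sums 5 and 7 are each reached by two different pairs, so their rows are dropped; only the sum 6, which comes solely from (4, 2), survives.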
    

    If you need to extend this to a larger number of columns, you can use sum_horizontal() to make it more generic:

    columns = ["a","b"]
    
    df.filter(
        pl.struct(columns).n_unique().over(pl.sum_horizontal(columns)) == 1
    )