
Polars - compute on all other values in a window except this one


For each row, I'm trying to compute the standard deviation for the other values in a group excluding the row's value. A way to think about it is "what would the standard deviation for the group be if this row's value was removed". An example may be easier to parse:

import polars as pl

df = pl.DataFrame(
    {
        "answer": ["yes", "yes", "yes", "yes", "maybe", "maybe", "maybe"],
        "value": [5, 10, 7, 8, 6, 9, 10],
    }
)
┌────────┬───────┐
│ answer ┆ value │
│ ---    ┆ ---   │
│ str    ┆ i64   │
╞════════╪═══════╡
│ yes    ┆ 5     │
│ yes    ┆ 10    │
│ yes    ┆ 7     │
│ yes    ┆ 8     │
│ maybe  ┆ 6     │
│ maybe  ┆ 9     │
│ maybe  ┆ 10    │
└────────┴───────┘

I would like to add a column where, for example, the first row holds std([10, 7, 8]) = 1.527525, the sample standard deviation of the other "yes" values.
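To double-check that target number: Python's standard-library statistics.stdev computes the sample standard deviation (dividing by n - 1), which is also what Polars' Expr.std does by default (ddof=1). A minimal sketch:

```python
import statistics

# Sample standard deviation (n - 1 in the denominator) of the other "yes" values
others = [10, 7, 8]
print(round(statistics.stdev(others), 6))
```

A population standard deviation (ddof=0) of the same three values would come out smaller, so whichever method is used has to match the ddof=1 convention.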

I tried to hack something together and ended up with code that is horrible to read and also has a bug that I don't know how to work around:

df.with_columns(
    (
        (pl.col("value").sum().over(pl.col("answer")) - pl.col("value"))
        / (pl.col("value").count().over(pl.col("answer")) - 1)
    ).alias("average_other")
).with_columns(
    (
        (
            (
                (pl.col("value") - pl.col("average_other")).pow(2).sum().over(pl.col("answer"))
                - (pl.col("value") - pl.col("average_other")).pow(2)
            )
            / (pl.col("value").count().over(pl.col("answer")) - 1)
        ).sqrt()
    ).alias("std_dev_other")
)

I'm not sure I would recommend parsing that, but I'll point out at least one thing that is wrong:

(pl.col("value") - pl.col("average_other")).pow(2).sum().over(pl.col("answer"))

I want to compare every "value" in the window against the current row's "average_other", then square and sum over the window. Instead, each row's "value" is compared against that same row's "average_other".
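The pairwise comparison can be sidestepped with a little algebra. For one group, let n be its size, S the sum of its values x_j and Q the sum of their squares; all three are ordinary window aggregates. Dropping row i leaves n - 1 values with mean m_i. Expanding the square shows that their sum of squared deviations from m_i depends only on S, Q and x_i:

```latex
m_i = \frac{S - x_i}{n - 1}

\sum_{j \ne i} (x_j - m_i)^2
  = \sum_{j \ne i} x_j^2 - 2 m_i \sum_{j \ne i} x_j + (n - 1) m_i^2
  = (Q - x_i^2) - \frac{(S - x_i)^2}{n - 1}

s_i = \sqrt{\frac{(Q - x_i^2) - (S - x_i)^2 / (n - 1)}{n - 2}}
```

The last line divides by n - 2, i.e. (n - 1) - 1, which is the ddof=1 sample standard deviation of the n - 1 remaining values, so each group needs at least three values. The attempt above divides by count - 1 instead, which for the n - 1 remaining values corresponds to ddof=0 rather than the ddof=1 value 1.527525 the question targets.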

My main question is the "what is the best way to get the standard deviation while leaving out this value?" part. But I would also be interested in whether there is a way to do the comparison that I'm getting wrong above. Third, I'd welcome tips on how to write this in a way that makes it easy to understand what is going on.


Solution

  • I came up with something similar to @DeanMacGregor's answer:

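    import polars as pl

    df = (
        df.with_row_index()  # adds a row-number column named "index"
        # pair each row with every row that has the same answer
        .join(df.with_row_index(), on="answer")
        # drop the pairs in which a row is matched with itself
        .filter(pl.col("index") != pl.col("index_right"))
        # one group per original row; "value" now holds the other rows' values
        .group_by("answer", "index_right", maintain_order=True).agg(
            pl.col("value_right").first().alias("value"),
            pl.col("value").std().alias("stdev"),  # sample std, ddof=1
        )
        .sort("index_right")  # join output order may not match the input; restore it
        .drop("index_right")
    )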
    

    Join df (with a row index) to itself on "answer" and remove the rows where the two row indexes are identical. Then group by "answer" and "index_right" and (1) take the first item of "value_right", which is that row's own value, and (2) compute the standard deviation of the "value" group, which holds every other value in the same answer group.

    Result for

    df = pl.DataFrame({
        "answer": ["yes", "yes", "yes", "yes", "yes", "maybe", "maybe", "maybe", "maybe"],
        "value": [5, 10, 7, 8, 4, 6, 9, 10, 4],
    })
    

    is

    ┌────────┬───────┬──────────┐
    │ answer ┆ value ┆ stdev    │
    │ ---    ┆ ---   ┆ ---      │
    │ str    ┆ i64   ┆ f64      │
    ╞════════╪═══════╪══════════╡
    │ yes    ┆ 5     ┆ 2.5      │
    │ yes    ┆ 10    ┆ 1.825742 │
    │ yes    ┆ 7     ┆ 2.753785 │
    │ yes    ┆ 8     ┆ 2.645751 │
    │ yes    ┆ 4     ┆ 2.081666 │
    │ maybe  ┆ 6     ┆ 3.21455  │
    │ maybe  ┆ 9     ┆ 3.05505  │
    │ maybe  ┆ 10    ┆ 2.516611 │
    │ maybe  ┆ 4     ┆ 2.081666 │
    └────────┴───────┴──────────┘
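The self-join approach above is, in effect, a brute-force leave-one-out loop. A plain-Python sketch of the same logic (standard library only, not Polars) can be handy as a cross-check of the result table; like the self-join, its work grows with the square of each group's size.

```python
import statistics

answers = ["yes", "yes", "yes", "yes", "yes", "maybe", "maybe", "maybe", "maybe"]
values = [5, 10, 7, 8, 4, 6, 9, 10, 4]

stdevs = []
for i, answer_i in enumerate(answers):
    # Same as joining on "answer" and filtering index != index_right
    others = [
        v for j, (a, v) in enumerate(zip(answers, values))
        if a == answer_i and j != i
    ]
    # Same as pl.col("value").std(): sample standard deviation (ddof=1)
    stdevs.append(statistics.stdev(others))

for a, v, s in zip(answers, values, stdevs):
    print(a, v, round(s, 6))
```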