
Polars: how can I calculate lagged correlations between days?


I have a polars dataframe as below:

import polars as pl

df = pl.DataFrame(
    {
        "class": [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        "day": [1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 4],
        "id": [1, 2, 3, 2, 3, 4, 1, 2, 5, 2, 1, 3, 4],
        "value": [1, 2, 2, 3, 5, 2, 1, 2, 7, 3, 5, 3, 4],
    }
)

The result I want to have is:

  • Group by "class" (there is only one class in this example, but assume there are many).
  • Calculate the correlation for every possible pair of days, e.g., between day 1 and day 2, between day 2 and day 4, etc.
  • For a given pair of days, the two series are taken from "value" and matched by "id", and the correlation is computed over only the ids present on both days. For example, the correlation between day 1 and day 4 is the correlation between [1, 2, 2] and [5, 3, 3] (ids 1, 2 and 3).

I would like the result to be structured like this:

class cor_day_1_2 cor_day_1_3 cor_day_1_4 cor_day_2_3 cor_day_2_4 cor_day_3_4
1     -           -           -           -           -           -
.
.
.

I tried starting with df.pivot but got stuck for a few reasons:

  • I would then need to transpose the result (which could be expensive).
  • Otherwise, I would need a row-wise correlation, which I don't think is supported out of the box.

Thanks in advance for any help.


Solution

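  • Here's an attempt to get started: self-join on "class" and "id" with .join(), then keep only the rows where day < day_right. That filter drops same-day matches as well as the mirrored duplicate pairs (e.g. 4-1 alongside 1-4).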

    (df.join(df, on=["class", "id"])
       .filter(pl.col("day") < pl.col("day_right"))
       .group_by("class", "day", "day_right")
       .all()
    )
    
    shape: (6, 6)
    ┌───────┬─────┬───────────┬───────────┬───────────┬─────────────┐
    │ class | day | day_right | id        | value     | value_right │
    │ ---   | --- | ---       | ---       | ---       | ---         │
    │ i64   | i64 | i64       | list[i64] | list[i64] | list[i64]   │
    ╞═══════╪═════╪═══════════╪═══════════╪═══════════╪═════════════╡
    │ 1     | 3   | 4         | [2, 1]    | [2, 1]    | [3, 5]      │
    │ 1     | 1   | 2         | [2, 3]    | [2, 2]    | [3, 5]      │
    │ 1     | 2   | 3         | [2]       | [3]       | [2]         │
    │ 1     | 1   | 4         | [2, 1, 3] | [2, 1, 2] | [3, 5, 3]   │
    │ 1     | 2   | 4         | [2, 3, 4] | [3, 5, 2] | [3, 3, 4]   │
    │ 1     | 1   | 3         | [1, 2]    | [1, 2]    | [1, 2]      │
    └───────┴─────┴───────────┴───────────┴───────────┴─────────────┘
    

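    Update: .join_where() has since been added to Polars. It lets the day < day_right condition be part of the join itself instead of a separate filter step, which should be more efficient.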

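    df.join_where(
        df,
        pl.col("id") == pl.col("id_right"),        # same id on both sides
        pl.col("class") == pl.col("class_right"),  # same class on both sides
        pl.col("day") < pl.col("day_right"),       # each day pair once, no same-day matches
    )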