
Rolling correlation and average (last 3) Per Group in PySpark


I have a DataFrame like this:

data = [("ID1", 1, 5), ("ID1", 2, 6), ("ID1", 3, 7),
        ("ID1", 4, 4), ("ID1", 5, 2), ("ID1", 6, 2),
        ("ID2", 1, 4), ("ID2", 2, 6), ("ID2", 3, 1), ("ID2", 4, 1), ("ID2", 5, 4)]
df = spark.createDataFrame(data, ["ID", "colA", "colB"])
df.show()

+---+----+----+
| ID|colA|colB|
+---+----+----+
|ID1|   1|   5|
|ID1|   2|   6|
|ID1|   3|   7|
|ID1|   4|   4|
|ID1|   5|   2|
|ID1|   6|   2|
|ID2|   1|   4|
|ID2|   2|   6|
|ID2|   3|   1|
|ID2|   4|   1|
|ID2|   5|   4|
+---+----+----+

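For each ID, I want to calculate a rolling average of colB and a rolling correlation between colA and colB over the last 3 rows (the current row and up to two rows before it, ordered by colA).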

For example, for ID1:
first element (5): average = 5, corr = 0
first 2 elements (5, 6): average = 5.5, corr with colA = 1
first 3 elements (5, 6, 7): average = 6, corr with colA = 1
elements (6, 7, 4): average = 5.66, corr with colA = -0.65
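
The hand-computed values above can be checked without Spark. The following is a minimal pure-Python sketch, not part of the original question: the helper names `pearson` and `rolling_last3` are made up for illustration, and returning 0.0 when the correlation is undefined (a single row, or zero variance) is an assumption chosen to match the expected output.

```python
from math import sqrt

def pearson(xs, ys):
    """Pearson correlation; returns 0.0 when it is undefined
    (fewer than 2 points or zero variance). That fallback is an assumption."""
    n = len(xs)
    if n < 2:
        return 0.0
    mx, my = sum(xs) / n, sum(ys) / n
    cov = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    vx = sum((x - mx) ** 2 for x in xs)
    vy = sum((y - my) ** 2 for y in ys)
    if vx == 0 or vy == 0:
        return 0.0
    return cov / sqrt(vx * vy)

def rolling_last3(col_a, col_b, size=3):
    """Return one (corr, avg) pair per row, computed over the current row
    and up to size - 1 preceding rows."""
    out = []
    for i in range(len(col_b)):
        lo = max(0, i - size + 1)
        xs, ys = col_a[lo:i + 1], col_b[lo:i + 1]
        out.append((pearson(xs, ys), sum(ys) / len(ys)))
    return out

# ID1 rows from the question, already ordered by colA
id1 = rolling_last3([1, 2, 3, 4, 5, 6], [5, 6, 7, 4, 2, 2])
for corr, avg in id1:
    print(round(corr, 2), round(avg, 2))
```

Note that `round` rounds to nearest, so the averages may differ in the last digit from the truncated values shown in the question.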


The expected output looks like this:

    +---+----+----+----------+---------+
    | ID|colA|colB|corr_last3|avg_last3|
    +---+----+----+----------+---------+
    |ID1|   1|   5|         0|        5|
    |ID1|   2|   6|         1|      5.5|
    |ID1|   3|   7|         1|        6|
    |ID1|   4|   4|     -0.65|     5.66|
    |ID1|   5|   2|     -0.99|     4.33|
    |ID1|   6|   2|     -0.86|     2.66|
    |ID2|   1|   4|         0|        4|
    |ID2|   2|   6|         1|        5|
    |ID2|   3|   1|     -0.59|     3.66|
    |ID2|   4|   1|     -0.86|     2.66|
    |ID2|   5|   4|      0.86|        2|
    +---+----+----+----------+---------+


Solution

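  • You can do this with the built-in functions avg and corr applied over a window. Here is a Scala solution. It correlates the row index (row_number ordered by colA) with colB, which equals the correlation with colA for this data because colA is a run of consecutive integers within each ID; for arbitrary colA values, pass $"colA" to corr instead: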

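    import org.apache.spark.sql.expressions.Window
    import org.apache.spark.sql.functions.{avg, corr, row_number, when}
    import spark.implicits._ // for the $"col" syntax; already in scope in spark-shell

    val byId  = Window.partitionBy($"ID").orderBy($"colA")
    val last3 = byId.rowsBetween(-2L, Window.currentRow) // current row and the 2 before it
    df
      .withColumn("indices", row_number().over(byId))
      // correlation is not defined for a single row, so report 0.0 for the first row of each ID
      .withColumn("corr_last3", when($"indices" > 1, corr($"indices", $"colB").over(last3)).otherwise(0.0))
      .withColumn("avg_last3", avg($"colB").over(last3))
      .drop($"indices")
      .orderBy($"ID", $"colA")
      .show()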
    

    gives:

    +---+----+----+-------------------+------------------+
    | ID|colA|colB|         corr_last3|         avg_last3|
    +---+----+----+-------------------+------------------+
    |ID1|   1|   5|                0.0|               5.0|
    |ID1|   2|   6|                1.0|               5.5|
    |ID1|   3|   7|                1.0|               6.0|
    |ID1|   4|   4|-0.6546536707079772| 5.666666666666667|
    |ID1|   5|   2|-0.9933992677987828| 4.333333333333333|
    |ID1|   6|   2|-0.8660254037844386|2.6666666666666665|
    |ID2|   1|   4|                0.0|               4.0|
    |ID2|   2|   6|                1.0|               5.0|
    |ID2|   3|   1|-0.5960395606792697|3.6666666666666665|
    |ID2|   4|   1|-0.8660254037844387|2.6666666666666665|
    |ID2|   5|   4| 0.8660254037844387|               2.0|
    +---+----+----+-------------------+------------------+