I have a dataframe like this
data = [("ID1", 1, 5), ("ID1", 2, 6), ("ID1", 3, 7),
        ("ID1", 4, 4), ("ID1", 5, 2), ("ID1", 6, 2),
        ("ID2", 1, 4), ("ID2", 2, 6), ("ID2", 3, 1), ("ID2", 4, 1), ("ID2", 5, 4)]
df = spark.createDataFrame(data, ["ID", "colA", "colB"])
df.show()
+---+----+----+
| ID|colA|colB|
+---+----+----+
|ID1|   1|   5|
|ID1|   2|   6|
|ID1|   3|   7|
|ID1|   4|   4|
|ID1|   5|   2|
|ID1|   6|   2|
|ID2|   1|   4|
|ID2|   2|   6|
|ID2|   3|   1|
|ID2|   4|   1|
|ID2|   5|   4|
+---+----+----+
For each row, I want to calculate, per ID, the average of colB and the correlation between colA and colB over the last 3 rows (the current row and up to two preceding rows, ordered by colA).
For example, for ID1:
first element (5): average = 5, corr = 0
first 2 elements (5, 6): average = 5.5, corr with colA = 1
first 3 elements (5, 6, 7): average = 6, corr with colA = 1
elements (6, 7, 4): average = 5.66, corr with colA = -0.65
The expected output looks like this:
+---+----+----+----------+---------+
| ID|colA|colB|corr_last3|avg_last3|
+---+----+----+----------+---------+
|ID1|   1|   5|         0|        5|
|ID1|   2|   6|         1|      5.5|
|ID1|   3|   7|         1|        6|
|ID1|   4|   4|     -0.65|     5.66|
|ID1|   5|   2|     -0.99|     4.33|
|ID1|   6|   2|     -0.86|     2.66|
|ID2|   1|   4|         0|        4|
|ID2|   2|   6|         1|        5|
|ID2|   3|   1|     -0.59|     3.66|
|ID2|   4|   1|     -0.86|     2.66|
|ID2|   5|   4|      0.86|        2|
+---+----+----+----------+---------+
You can do this with the built-in functions avg and corr, applied over a window of the current row and the two preceding rows. Here is a Scala solution:
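import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{avg, corr, row_number, when}
import spark.implicits._ // enables the $"col" syntax; assumes a SparkSession named spark
// Rows of each ID ordered by colA; last3 spans the current row and the two before it
val byId = Window.partitionBy($"ID").orderBy($"colA")
val last3 = byId.rowsBetween(-2L, Window.currentRow)
df
  .withColumn("indices", row_number().over(byId))
  // A single row has no defined correlation, so the first row of each ID gets 0.0.
  // indices is an affine function of colA here (colA is consecutive), so the correlation is the same.
  .withColumn("corr_last3", when($"indices" > 1, corr($"indices", $"colB").over(last3)).otherwise(0.0))
  .withColumn("avg_last3", avg($"colB").over(last3))
  .drop($"indices")
  .orderBy($"ID", $"colA")
  .show()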
gives:
+---+----+----+-------------------+------------------+
| ID|colA|colB|         corr_last3|         avg_last3|
+---+----+----+-------------------+------------------+
|ID1|   1|   5|                0.0|               5.0|
|ID1|   2|   6|                1.0|               5.5|
|ID1|   3|   7|                1.0|               6.0|
|ID1|   4|   4|-0.6546536707079772| 5.666666666666667|
|ID1|   5|   2|-0.9933992677987828| 4.333333333333333|
|ID1|   6|   2|-0.8660254037844386|2.6666666666666665|
|ID2|   1|   4|                0.0|               4.0|
|ID2|   2|   6|                1.0|               5.0|
|ID2|   3|   1|-0.5960395606792697|3.6666666666666665|
|ID2|   4|   1|-0.8660254037844387|2.6666666666666665|
|ID2|   5|   4| 0.8660254037844387|               2.0|
+---+----+----+-------------------+------------------+
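As a sanity check on these numbers, here is a minimal plain-Python sketch of the same rolling computation, without Spark. The helper names pearson and last3_stats are hypothetical, chosen only for this illustration. Returning 0.0 when the correlation is undefined (a single row, or zero variance in the window) is an assumption of this sketch that mirrors the otherwise(0.0) branch for the first row.

```python
from math import sqrt

def pearson(xs, ys):
    """Pearson correlation of two equal-length sequences.

    Assumption of this sketch: return 0.0 when the correlation is
    undefined (fewer than 2 points, or zero variance in either input).
    """
    n = len(xs)
    if n < 2:
        return 0.0
    mx, my = sum(xs) / n, sum(ys) / n
    sxy = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    sxx = sum((x - mx) ** 2 for x in xs)
    syy = sum((y - my) ** 2 for y in ys)
    if sxx == 0 or syy == 0:
        return 0.0
    return sxy / sqrt(sxx * syy)

def last3_stats(rows, size=3):
    """rows: (colA, colB) pairs for a single ID, already sorted by colA.

    Returns one (corr_last3, avg_last3) pair per row, computed over the
    current row and up to size - 1 preceding rows.
    """
    out = []
    for i in range(len(rows)):
        window = rows[max(0, i - size + 1): i + 1]
        a = [r[0] for r in window]
        b = [r[1] for r in window]
        out.append((pearson(a, b), sum(b) / len(b)))
    return out

# The ID1 group from the question
id1 = [(1, 5), (2, 6), (3, 7), (4, 4), (5, 2), (6, 2)]
for corr_value, avg_value in last3_stats(id1):
    print(round(corr_value, 2), round(avg_value, 2))
```

Running it on each group should agree with the table above up to floating-point rounding. Note that it correlates colA with colB directly, which is what the question asks for, whereas the Scala code uses the row index; the two coincide here because colA is consecutive within each ID.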