The input DataFrame looks like this:
+---+----------+----------+--------+-----+-------------------+
| id|product_id|sales_date|quantity|price| timestampCol|
+---+----------+----------+--------+-----+-------------------+
| 1| 1|2022-12-31| 10| 10.0|2022-12-31 00:00:00|
| 2| 1|2023-01-01| 10| 10.0|2023-01-01 00:00:00|
| 3| 1|2023-01-02| 12| 12.0|2023-01-02 00:00:00|
| 4| 1|2023-01-03| 15| 15.0|2023-01-03 00:00:00|
| 5| 2|2023-01-01| 8| 8.0|2023-01-01 00:00:00|
| 6| 2|2023-01-02| 10| 10.0|2023-01-02 00:00:00|
| 7| 2|2023-01-03| 12| 12.0|2023-01-03 00:00:00|
+---+----------+----------+--------+-----+-------------------+
The task is to calculate a 2-day moving average of price, partitioned by product_id. The moving-average frame should include the current timestamp and the previous one. For example, for id = 2 the average must be (10.0 + 10.0) / 2, for id = 3 it must be (12.0 + 10.0) / 2, and so on.
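To pin down the expected values, here is a minimal plain-Scala sketch (no Spark involved) that averages each row's price with the previous row's price within the same product, ordered by date. The names Sale, ExpectedMovingAvg and movingAvg are hypothetical and exist only for this illustration:

```scala
// Hypothetical plain-Scala model of the sample rows (no Spark involved).
final case class Sale(id: Int, productId: Int, salesDate: String, price: Double)

object ExpectedMovingAvg {
  // The seven sample rows from the input table (quantity and timestamp omitted).
  val sales: Seq[Sale] = Seq(
    Sale(1, 1, "2022-12-31", 10.0),
    Sale(2, 1, "2023-01-01", 10.0),
    Sale(3, 1, "2023-01-02", 12.0),
    Sale(4, 1, "2023-01-03", 15.0),
    Sale(5, 2, "2023-01-01", 8.0),
    Sale(6, 2, "2023-01-02", 10.0),
    Sale(7, 2, "2023-01-03", 12.0)
  )

  /** For each id: mean of the current row's price and the previous row's price
    * within the same product, ordered by date (the first row averages only itself). */
  def movingAvg(rows: Seq[Sale]): Map[Int, Double] =
    rows.groupBy(_.productId).values.flatMap { group =>
      val ordered = group.sortBy(_.salesDate) // ISO-8601 dates sort correctly as strings
      ordered.indices.map { i =>
        // previous row (if any) plus the current row
        val frame = ordered.slice(math.max(0, i - 1), i + 1)
        ordered(i).id -> (frame.map(_.price).sum / frame.size)
      }
    }.toMap
}
```

For the sample data this gives 1 → 10.0, 2 → 10.0, 3 → 11.0, 4 → 13.5, 5 → 8.0, 6 → 9.0, 7 → 11.0.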
I tried the following code:
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{avg, round}

val productWindow = Window
  .partitionBy(countriesWithTS("product_id"))
  .orderBy(countriesWithTS("timestampCol"))
  .rowsBetween(2, Window.currentRow)

countriesWithTS
  .withColumn("moved_avg",
    round(avg(countriesWithTS("price")).over(productWindow), 2))
  .show()
It returns a DataFrame with null values in the "moved_avg" column:
+---+----------+----------+--------+-----+-------------------+---------+
| id|product_id|sales_date|quantity|price| timestampCol|moved_avg|
+---+----------+----------+--------+-----+-------------------+---------+
| 1| 1|2022-12-31| 10| 10.0|2022-12-31 00:00:00| null|
| 2| 1|2023-01-01| 10| 10.0|2023-01-01 00:00:00| null|
| 3| 1|2023-01-02| 12| 12.0|2023-01-02 00:00:00| null|
| 4| 1|2023-01-03| 15| 15.0|2023-01-03 00:00:00| null|
| 5| 2|2023-01-01| 8| 8.0|2023-01-01 00:00:00| null|
| 6| 2|2023-01-02| 10| 10.0|2023-01-02 00:00:00| null|
| 7| 2|2023-01-03| 12| 12.0|2023-01-03 00:00:00| null|
+---+----------+----------+--------+-----+-------------------+---------+
The problem seems to be somewhere in the "rowsBetween" argument. When I comment out that line, the average is calculated successfully, but that is not what I need, since it is not a moving average.
Also note that "price" is declared in the schema as StructField("price", FloatType, nullable = true).
What am I doing wrong?
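The offsets passed to rowsBetween(start, end) are relative to the current row: negative values mean rows before it, positive values mean rows after it, and Window.currentRow is 0. So rowsBetween(2, Window.currentRow) asks for a frame that starts two rows after the current row and ends at the current row. That frame is empty, so avg has nothing to aggregate and returns null for every row (the FloatType of price plays no part in this). If you want to include the current row and the row before it, you need: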
.rowsBetween(-1, Window.currentRow)
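To see why rowsBetween(2, Window.currentRow) yields null while rowsBetween(-1, Window.currentRow) works, here is a simplified plain-Scala model of a ROWS frame. It is an illustrative assumption, not Spark's implementation: offsets are relative to the current row (negative = preceding, positive = following), None stands in for an unbounded boundary, and avg over an empty frame returns None, playing the role of SQL null. RowFrame, frame and avg are hypothetical names:

```scala
object RowFrame {
  /** Simplified ROWS frame: the slice of `partition` selected for row `current`.
    * Offsets are relative to `current`; None means unbounded on that side. */
  def frame[A](partition: IndexedSeq[A], current: Int,
               start: Option[Int], end: Option[Int]): IndexedSeq[A] = {
    val from  = start.fold(0)(s => math.max(0, current + s))
    val until = end.fold(partition.size)(e => math.min(partition.size, current + e + 1))
    // slice returns an empty result when from >= until, i.e. start lies after end
    partition.slice(from, until)
  }

  /** avg over a frame; None plays the role of SQL null for an empty frame. */
  def avg(values: IndexedSeq[Double]): Option[Double] =
    if (values.isEmpty) None else Some(values.sum / values.size)
}
```

With the prices of product 1, Vector(10.0, 10.0, 12.0, 15.0), frame(p, 1, Some(2), Some(0)) is empty (the same reason every moved_avg came back null), frame(p, 2, Some(-1), Some(0)) averages to 11.0, and frame(p, 3, None, Some(0)) gives the cumulative average 11.75. The last case resembles what happens when rowsBetween is removed: per the Spark Window documentation, when an ordering is defined and no frame is given, a growing frame from unboundedPreceding to currentRow is used.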
Full example:
val productWindow = Window.partitionBy(countriesWithTS("product_id"))
  .orderBy(countriesWithTS("timestampCol")).rowsBetween(-1, Window.currentRow)
countriesWithTS.withColumn("moved_avg",
  round(avg(countriesWithTS("price")).over(productWindow), 2)).show()
Results:
+---+----------+----------+--------+-----+-------------------+---------+
| id|product_id|sales_date|quantity|price| timestampCol|moved_avg|
+---+----------+----------+--------+-----+-------------------+---------+
| 1| 1|2022-12-31| 10| 10.0|2022-12-31 00:00:00| 10.0|
| 2| 1|2023-01-01| 10| 10.0|2023-01-01 00:00:00| 10.0|
| 3| 1|2023-01-02| 12| 12.0|2023-01-02 00:00:00| 11.0|
| 4| 1|2023-01-03| 15| 15.0|2023-01-03 00:00:00| 13.5|
| 5| 2|2023-01-01| 8| 8.0|2023-01-01 00:00:00| 8.0|
| 6| 2|2023-01-02| 10| 10.0|2023-01-02 00:00:00| 9.0|
| 7| 2|2023-01-03| 12| 12.0|2023-01-03 00:00:00| 11.0|
+---+----------+----------+--------+-----+-------------------+---------+
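Note that rowsBetween counts rows, not days: if a product has a gap in its sales dates, rowsBetween(-1, Window.currentRow) still averages the current row with the previous row, however far back it is. If "2 days" should be taken literally, a common alternative is a range frame over the timestamp in epoch seconds, for example ordering by countriesWithTS("timestampCol").cast("long") and using .rangeBetween(-86400, Window.currentRow) (86400 seconds = 1 day, so the frame covers the current day and the previous one); that Spark variant is not tested here. The plain-Scala sketch below only contrasts the two frame definitions on hypothetical (epochSeconds, price) data containing a gap; RowsVsRange, rowsAvg and rangeAvg are illustrative names:

```scala
object RowsVsRange {
  /** Rows frame: current row plus the one before it, regardless of the time gap.
    * Points are assumed to be sorted by timestamp. */
  def rowsAvg(points: IndexedSeq[(Long, Double)]): IndexedSeq[Double] =
    points.indices.map { i =>
      val frame = points.slice(math.max(0, i - 1), i + 1)
      frame.map(_._2).sum / frame.size
    }

  /** Range frame: every point whose timestamp u satisfies
    * t - windowSeconds <= u <= t, where t is the current point's timestamp. */
  def rangeAvg(points: IndexedSeq[(Long, Double)], windowSeconds: Long): IndexedSeq[Double] =
    points.map { case (t, _) =>
      val frame = points.filter { case (u, _) => u >= t - windowSeconds && u <= t }
      frame.map(_._2).sum / frame.size
    }
}
```

With points at day 0 (10.0), day 1 (12.0) and day 3 (20.0), rowsAvg gives 10.0, 11.0, 16.0, while rangeAvg with a one-day window gives 10.0, 11.0, 20.0, because day 1 falls outside the last row's two-day range. For the question's data, which has no gaps, both approaches produce the same result.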