Tags: scala, apache-spark, window-functions, spark-window-function

Spark (Scala): Moving average with Window function


The input DataFrame looks like this:

+---+----------+----------+--------+-----+-------------------+
| id|product_id|sales_date|quantity|price|       timestampCol|
+---+----------+----------+--------+-----+-------------------+                       
|  1|         1|2022-12-31|      10| 10.0|2022-12-31 00:00:00|
|  2|         1|2023-01-01|      10| 10.0|2023-01-01 00:00:00|
|  3|         1|2023-01-02|      12| 12.0|2023-01-02 00:00:00|
|  4|         1|2023-01-03|      15| 15.0|2023-01-03 00:00:00|
|  5|         2|2023-01-01|       8|  8.0|2023-01-01 00:00:00|
|  6|         2|2023-01-02|      10| 10.0|2023-01-02 00:00:00|
|  7|         2|2023-01-03|      12| 12.0|2023-01-03 00:00:00|
+---+----------+----------+--------+-----+-------------------+
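
For completeness, the sample data above can be reproduced with something like this (the session setup and the toDF construction are assumptions; only the name countriesWithTS is used in the rest of the post):

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._

// assumed local session; adjust to your environment
val spark = SparkSession.builder().master("local[*]").appName("moving-avg").getOrCreate()
import spark.implicits._

// hypothetical reconstruction of the sample rows shown above
val countriesWithTS = Seq(
  (1, 1, "2022-12-31", 10, 10.0f),
  (2, 1, "2023-01-01", 10, 10.0f),
  (3, 1, "2023-01-02", 12, 12.0f),
  (4, 1, "2023-01-03", 15, 15.0f),
  (5, 2, "2023-01-01",  8,  8.0f),
  (6, 2, "2023-01-02", 10, 10.0f),
  (7, 2, "2023-01-03", 12, 12.0f)
).toDF("id", "product_id", "sales_date", "quantity", "price")
  .withColumn("timestampCol", to_timestamp(col("sales_date")))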

The task is to calculate a two-day moving average of price, partitioned by product_id. The moving-average frame includes the current row and the previous row. For example, for id = 2 the average must be (10.0 + 10.0)/2, for id = 3 it is (12.0 + 10.0)/2, and so on.

I tried the following code:

val productWindow = Window
  .partitionBy(countriesWithTS("product_id")).orderBy(countriesWithTS("timestampCol"))
  .rowsBetween(2, Window.currentRow)

countriesWithTS
  .withColumn("moved_avg",
    round(avg(countriesWithTS("price")).over(productWindow), 2))
  .show()

It returns the DataFrame with null values in the "moved_avg" column:

 +---+----------+----------+--------+-----+-------------------+---------+
 | id|product_id|sales_date|quantity|price|       timestampCol|moved_avg|
 +---+----------+----------+--------+-----+-------------------+---------+
 |  1|         1|2022-12-31|      10| 10.0|2022-12-31 00:00:00|     null|
 |  2|         1|2023-01-01|      10| 10.0|2023-01-01 00:00:00|     null|
 |  3|         1|2023-01-02|      12| 12.0|2023-01-02 00:00:00|     null|
 |  4|         1|2023-01-03|      15| 15.0|2023-01-03 00:00:00|     null|
 |  5|         2|2023-01-01|       8|  8.0|2023-01-01 00:00:00|     null|
 |  6|         2|2023-01-02|      10| 10.0|2023-01-02 00:00:00|     null|
 |  7|         2|2023-01-03|      12| 12.0|2023-01-03 00:00:00|     null|
 +---+----------+----------+--------+-----+-------------------+---------+

The problem is somewhere in the "rowsBetween" arguments. When I comment out that part of the code, the average is calculated successfully (however, that is not what I need, as it is not a moving average).

Note also that the "price" type is represented in the schema as StructField("price", FloatType, nullable = true).

What am I doing wrong?


Solution

  • The frame bounds in rowsBetween(start, end) are offsets relative to the current row: negative values mean preceding rows, positive values mean following rows. rowsBetween(2, Window.currentRow) therefore defines a frame that starts two rows after the current row yet ends at the current row, i.e. an empty frame, which is why avg returns null. If you want to include the current row and the row before it, you need:

    .rowsBetween(-1, Window.currentRow)
    

    Full example:

    val productWindow = Window
      .partitionBy(countriesWithTS("product_id"))
      .orderBy(countriesWithTS("timestampCol"))
      .rowsBetween(-1, Window.currentRow)

    countriesWithTS
      .withColumn("moved_avg",
        round(avg(countriesWithTS("price")).over(productWindow), 2))
      .show()
    

    Results:

    +---+----------+----------+--------+-----+-------------------+---------+
    | id|product_id|sales_date|quantity|price|       timestampCol|moved_avg|
    +---+----------+----------+--------+-----+-------------------+---------+
    |  1|         1|2022-12-31|      10| 10.0|2022-12-31 00:00:00|     10.0|
    |  2|         1|2023-01-01|      10| 10.0|2023-01-01 00:00:00|     10.0|
    |  3|         1|2023-01-02|      12| 12.0|2023-01-02 00:00:00|     11.0|
    |  4|         1|2023-01-03|      15| 15.0|2023-01-03 00:00:00|     13.5|
    |  5|         2|2023-01-01|       8|  8.0|2023-01-01 00:00:00|      8.0|
    |  6|         2|2023-01-02|      10| 10.0|2023-01-02 00:00:00|      9.0|
    |  7|         2|2023-01-03|      12| 12.0|2023-01-03 00:00:00|     11.0|
    +---+----------+----------+--------+-----+-------------------+---------+
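
  • Note that rowsBetween counts physical rows, not days. If a product can have missing dates and you want the frame defined by actual time (the current day plus the previous calendar day), a range-based frame over the timestamp is a common alternative. A sketch under that assumption:

    import org.apache.spark.sql.expressions.Window
    import org.apache.spark.sql.functions._

    // one day expressed in seconds, to match timestampCol cast to epoch seconds
    val daySeconds = 24L * 60 * 60

    val timeWindow = Window
      .partitionBy(countriesWithTS("product_id"))
      .orderBy(countriesWithTS("timestampCol").cast("long"))
      .rangeBetween(-daySeconds, Window.currentRow)

    countriesWithTS
      .withColumn("moved_avg",
        round(avg(countriesWithTS("price")).over(timeWindow), 2))
      .show()

    On this sample data, with exactly one row per product per day, it produces the same result as the rowsBetween version.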