scala, apache-spark, recursion, apache-spark-sql, window-functions

Run a Cumulative/Iterative Custom Method on a Column in Spark Scala


Hi, I am new to Spark/Scala. I have been trying (AKA failing) to create a column in a Spark dataframe based on a particular recursive formula.

Here it is in pseudocode:

someDf.col2[0] = 0

for i > 0:
    someDf.col2[i] = x * someDf.col1[i-1] + (1 - x) * someDf.col2[i-1]
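
On a plain Scala sequence the recurrence could be written out like this (a minimal sketch with made-up values for x and col1, only to make the indexing explicit); the hard part is doing the same per id inside a dataframe:

// Hypothetical inputs, just to illustrate the recurrence itself.
val x = 0.5
val col1 = Seq(0.0, 1.0, 1.0, 0.0)

// col2(0) = 0; col2(i) = x * col1(i - 1) + (1 - x) * col2(i - 1)
val col2 = col1.init.scanLeft(0.0) {
  case (prev, c1) => x * c1 + (1 - x) * prev
}
// col2 == Seq(0.0, 0.0, 0.5, 0.75)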

To go into more detail, here is my starting point: this dataframe is the result of aggregations at the level of both dates and individual ids.

All further calculations have to happen with respect to that particular id, and have to take into account what happened in the previous week.

To illustrate this, I have simplified the values to zeros and ones, removed the multipliers x and 1-x, and initialized col2 to zero.

var someDf = Seq(("2016-01-10 00:00:00.0","385608",0,0), 
         ("2016-01-17 00:00:00.0","385608",0,0),
         ("2016-01-24 00:00:00.0","385608",1,0),
         ("2016-01-31 00:00:00.0","385608",1,0),
         ("2016-02-07 00:00:00.0","385608",1,0),
         ("2016-02-14 00:00:00.0","385608",1,0),
         ("2016-01-17 00:00:00.0","105010",0,0),
         ("2016-01-24 00:00:00.0","105010",1,0),
         ("2016-01-31 00:00:00.0","105010",0,0),
         ("2016-02-07 00:00:00.0","105010",1,0)
        ).toDF("dates", "id", "col1","col2" )

someDf.show()
+--------------------+------+----+----+
|               dates|    id|col1|col2|
+--------------------+------+----+----+
|2016-01-10 00:00:...|385608|   0|   0|
|2016-01-17 00:00:...|385608|   0|   0|
|2016-01-24 00:00:...|385608|   1|   0|
|2016-01-31 00:00:...|385608|   1|   0|
|2016-02-07 00:00:...|385608|   1|   0|
|2016-02-14 00:00:...|385608|   1|   0|
|2016-01-17 00:00:...|105010|   0|   0|
|2016-01-24 00:00:...|105010|   1|   0|
|2016-01-31 00:00:...|105010|   0|   0|
|2016-02-07 00:00:...|105010|   1|   0|
+--------------------+------+----+----+

What I have tried so far vs. what is desired:

import org.apache.spark.sql.functions._
import org.apache.spark.sql.expressions.Window

val date_id_window = Window.partitionBy("id").orderBy(asc("dates")) 

someDf.withColumn("col2", lag($"col1", 1).over(date_id_window) +
  lag($"col2", 1).over(date_id_window)).show()
+--------------------+------+----+----+ / +--------------------+
|               dates|    id|col1|col2| / | what_col2_should_be|
+--------------------+------+----+----+ / +--------------------+
|2016-01-17 00:00:...|105010|   0|null| / |                   0| 
|2016-01-24 00:00:...|105010|   1|   0| / |                   0|
|2016-01-31 00:00:...|105010|   0|   1| / |                   1|
|2016-02-07 00:00:...|105010|   1|   0| / |                   1|
|2016-01-10 00:00:...|385608|   0|null| / |                   0|
|2016-01-17 00:00:...|385608|   0|   0| / |                   0|
|2016-01-24 00:00:...|385608|   1|   0| / |                   0|
|2016-01-31 00:00:...|385608|   1|   1| / |                   1|
|2016-02-07 00:00:...|385608|   1|   1| / |                   2|
|2016-02-14 00:00:...|385608|   1|   1| / |                   3|
+--------------------+------+----+----+ / +--------------------+

Is there a way to do this with a Spark dataframe? I have seen multiple cumulative-type computations, but never one that includes the same column. I believe the problem is that the newly computed value for row i-1 is not considered; instead the old value for i-1 is used, which is always 0.
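
For contrast, a plain running total over the same window does work, because it only ever reads col1 and never the column being produced; a minimal sketch reusing the date_id_window defined above:

someDf.withColumn("cumsum",
  sum($"col1").over(date_id_window.rowsBetween(Window.unboundedPreceding, Window.currentRow))
).show()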

Any help would be appreciated.


Solution

  • A Dataset should work just fine:

    import spark.implicits._  // for .as[Record] and the $ syntax (already in scope in spark-shell)

    val x = 0.1

    // Typed view of the rows once the old col2 is dropped.
    case class Record(dates: String, id: String, col1: Int)

    someDf.drop("col2").as[Record]
      .groupByKey(_.id)                    // one group per id
      .flatMapGroups { (_, records) =>
        // Sort within the group; this timestamp format sorts correctly as a string.
        val sorted = records.toSeq.sortBy(_.dates)
        // scanLeft threads the running col2 value through the group; .tail drops the seed.
        sorted.scanLeft((null: Record, 0.0)) {
          case ((_, col2), record) => (record, x * record.col1 + (1 - x) * col2)
        }.tail
      }
      .select($"_1.*", $"_2".alias("col2"))
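
Note that this scanLeft folds the current row's col1 into the value stored next to that row, while the what_col2_should_be column in the question only looks at the previous row (col1[i-1] and col2[i-1], with 0 for the first row of each id). If the output has to follow that indexing exactly, a small adjustment of the same idea could look like the sketch below; it reuses someDf, Record and x from above and has only been checked against the sample data:

    someDf.drop("col2").as[Record]
      .groupByKey(_.id)
      .flatMapGroups { (_, records) =>
        val sorted = records.toSeq.sortBy(_.dates)
        // Seed with 0 for the first row, then build each value from the
        // previous row's col1 and the previous col2; zip pairs the values
        // back with their rows.
        val col2 = sorted.init.scanLeft(0.0) {
          case (prev, rec) => x * rec.col1 + (1 - x) * prev
        }
        sorted.zip(col2)
      }
      .select($"_1.*", $"_2".alias("col2"))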