apache-spark, apache-spark-sql, window-functions

SparkSQL Windows: Creating Frame Based On Array Column


I am looking to use SparkSQL's window function, but with a custom condition on the frame specification.

The dataframe being operated on is as follows:

+--------------------+--------------------+--------------------+-----+
|              userid|           elementid|       prerequisites|score|
+--------------------+--------------------+--------------------+-----+
|a                   |1                   |[]                  |  1  |
|a                   |2                   |[]                  |  1  |
|a                   |3                   |[]                  |  1  |
|b                   |1                   |[]                  |  1  |
|a                   |4                   |[1, 2]              |  1  |
+--------------------+--------------------+--------------------+-----+

Every element in the prerequisites column is a value in another row's elementid column.

I would like to create a window where I partition by userid, and then grab all preceding rows where elementid is contained in the present row's prerequisites column.

Once I have this window, I want to sum the score column over it.

Desired output for the above example:

+--------------------+--------------------+--------------------+-----+
|              userid|           elementid|       prerequisites|sum  |
+--------------------+--------------------+--------------------+-----+
|a                   |1                   |[]                  |  0  |
|a                   |2                   |[]                  |  0  |
|a                   |3                   |[]                  |  0  |
|b                   |1                   |[]                  |  0  |
|a                   |4                   |[1, 2]              |  2  |
+--------------------+--------------------+--------------------+-----+

Notice that element 4 of user a is the only row whose prerequisites appear in preceding rows for the same user, so it is the only row with a sum greater than 0.
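To pin down the intended semantics independently of Spark, here is a minimal plain-Scala sketch of the rule described above. It assumes that "preceding" means the row order shown in the table; the sample data has no explicit ordering column. The names ElementRow and prerequisiteSums are hypothetical and only serve the illustration.

```scala
// Hypothetical row type mirroring the columns of the example table.
final case class ElementRow(userid: String, elementid: String, prerequisites: Seq[String], score: Long)

// For each row, sum the scores of earlier rows of the same user whose
// elementid appears in the current row's prerequisites.
def prerequisiteSums(rows: Seq[ElementRow]): Seq[Long] =
  rows.zipWithIndex.map { case (current, i) =>
    rows.take(i) // only rows that precede the current one
      .filter(r => r.userid == current.userid && current.prerequisites.contains(r.elementid))
      .map(_.score)
      .sum
  }

// The five rows from the example, in the order shown.
val rows = Seq(
  ElementRow("a", "1", Seq.empty, 1L),
  ElementRow("a", "2", Seq.empty, 1L),
  ElementRow("a", "3", Seq.empty, 1L),
  ElementRow("b", "1", Seq.empty, 1L),
  ElementRow("a", "4", Seq("1", "2"), 1L)
)

val sums = prerequisiteSums(rows)
```

For the example rows this yields 0 for the first four rows and 2 for the last one, matching the desired output. It is quadratic in the number of rows, so it is meant as a reference for checking a Spark solution on small data, not as a replacement for one.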

The closest question I saw was this question, which utilises collect_list.

However, that doesn't construct a window so much as collect a potential list of IDs. Anyone have any ideas on how to construct the aforementioned window?


Solution

  • scala> import org.apache.spark.sql.expressions.{UserDefinedFunction, Window}
    scala> import org.apache.spark.sql.functions._
    
    scala> df.show()
    +------+---------+-------------+-----+
    |userid|elementid|prerequisites|score|
    +------+---------+-------------+-----+
    |     a|        1|           []|    1|
    |     a|        2|           []|    1|
    |     a|        3|           []|    1|
    |     b|        1|           []|    1|
    |     a|        4|       [1, 2]|    1|
    +------+---------+-------------+-----+
    
    scala> df.printSchema
    root
     |-- userid: string (nullable = true)
     |-- elementid: string (nullable = true)
     |-- prerequisites: array (nullable = true)
     |    |-- element: string (containsNull = true)
     |-- score: string (nullable = true)
    
    scala> val W = Window.partitionBy("userid")
    
    scala> val df1 = df.withColumn("elementidList", collect_set(col("elementid")).over(W))
                       // Collect (elementid, score) pairs together so keys and values cannot get out of step
                       .withColumn("elementidScoreMap", map_from_entries(collect_list(struct(col("elementid"), col("score").cast("long"))).over(W)))
                       .withColumn("common", array_intersect(col("prerequisites"), col("elementidList")))
                       .drop("elementidList", "score")
    
    scala> def getSumUDF: UserDefinedFunction = udf((score: Map[String, Long], ids: String) =>
         |   // Sum the scores of the listed prerequisites; ids missing from the map contribute 0
         |   ids.split(",").map(id => score.getOrElse(id, 0L)).sum)
    
    scala> df1.withColumn("sum", when(size(col("common")) =!= 0, getSumUDF(col("elementidScoreMap"), concat_ws(",", col("prerequisites")))).otherwise(lit(0)))
              .drop("elementidScoreMap", "common")
              .show()
    +------+---------+-------------+---+
    |userid|elementid|prerequisites|sum|
    +------+---------+-------------+---+
    |     b|        1|           []|  0|
    |     a|        1|           []|  0|
    |     a|        2|           []|  0|
    |     a|        3|           []|  0|
    |     a|        4|       [1, 2]|  2|
    +------+---------+-------------+---+
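
The answer above looks scores up through a per-user map and a UDF. As a possible alternative, the same lookup can be sketched with built-in functions only, by exploding prerequisites and joining back on (userid, elementid). This is a sketch under assumptions, not the answerer's method: it assumes each (userid, elementid) pair occurs once, and the column names prereq and prereq_score are made up for the illustration. Like the window in the answer, which has no orderBy, it does not restrict matches to preceding rows, because the sample data has no ordering column.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{coalesce, col, explode_outer, lit, sum}

// Local session for a standalone run. In spark-shell, skip these two lines:
// `spark` and its implicits are already available there.
val spark = SparkSession.builder().master("local[*]").appName("prerequisite-sum").getOrCreate()
import spark.implicits._

// Sample data with the same column types as the printed schema (all strings).
val df = Seq(
  ("a", "1", Seq.empty[String], "1"),
  ("a", "2", Seq.empty[String], "1"),
  ("a", "3", Seq.empty[String], "1"),
  ("b", "1", Seq.empty[String], "1"),
  ("a", "4", Seq("1", "2"), "1")
).toDF("userid", "elementid", "prerequisites", "score")

// Lookup side: one row per (userid, elementid) carrying its numeric score.
val scores = df.select(
  col("userid"),
  col("elementid").as("prereq"),
  col("score").cast("long").as("prereq_score")
)

// One row per (row, prerequisite); explode_outer keeps rows whose array is empty.
val exploded = df.select(
  col("userid"), col("elementid"), col("prerequisites"),
  explode_outer(col("prerequisites")).as("prereq")
)

// Left join so unmatched prerequisites contribute nothing, then sum per original row.
val result = exploded
  .join(scores, Seq("userid", "prereq"), "left")
  .groupBy("userid", "elementid", "prerequisites")
  .agg(coalesce(sum(col("prereq_score")), lit(0L)).as("sum"))
```

On the sample data, result holds a sum of 2 for user a's element 4 and 0 for every other row; the row order of a grouped result is not guaranteed. If the data had an ordering column (for example a timestamp), the "preceding rows" requirement could be expressed as an extra join condition comparing that column on both sides.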