Tags: apache-spark, apache-spark-sql, user-defined-functions, aggregation, window-functions

spark aggregate set of events per key including their change timestamps


For a dataframe of:

+----+--------+-------------------+----+
|user|      dt|         time_value|item|
+----+--------+-------------------+----+
| id1|20200101|2020-01-01 00:00:00|   A|
| id1|20200101|2020-01-01 10:00:00|   B|
| id1|20200101|2020-01-01 09:00:00|   A|
| id1|20200101|2020-01-01 11:00:00|   B|
+----+--------+-------------------+----+

I want to capture all the unique items, i.e. via collect_set, but retain each item's own time_value:

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.functions.unix_timestamp
import org.apache.spark.sql.functions.collect_set
import org.apache.spark.sql.types.TimestampType
import spark.implicits._  // assumes a SparkSession named `spark` (as in spark-shell); needed for .toDF
val timeFormat = "yyyy-MM-dd HH:mm"
val dx = Seq(
    ("id1", "20200101", "2020-01-01 00:00", "A"),
    ("id1", "20200101", "2020-01-01 10:00", "B"),
    ("id1", "20200101", "2020-01-01 9:00", "A"),
    ("id1", "20200101", "2020-01-01 11:00", "B")
  ).toDF("user", "dt", "time_value", "item")
  .withColumn("time_value", unix_timestamp(col("time_value"), timeFormat).cast(TimestampType))
dx.show


dx.groupBy("user", "dt").agg(collect_set("item")).show
+----+--------+-----------------+                                               
|user|      dt|collect_set(item)|
+----+--------+-----------------+
| id1|20200101|           [B, A]|
+----+--------+-----------------+

does not retain the time_value information of when the signal switched from A to B. How can I keep the time_value information for each item in the set?

Would it be possible to have the collect_set within a window function to achieve the desired result? Currently, I can only think of:

  1. use a window function to determine pairs of events
  2. filter to change events
  3. aggregate

which requires multiple shuffles. Alternatively, a UDF operating on sort_array(collect_list(struct(time_value, item))) would be possible, but that also seems rather clumsy.
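For reference, the clumsy variant I have in mind would look roughly like this (just a sketch; the collected array would still need post-processing, e.g. with a UDF, to drop consecutive duplicates):

import org.apache.spark.sql.functions.{collect_list, sort_array, struct}

// collect all (time_value, item) pairs per key, ordered by time_value
// (structs sort by their first field); deduplication would still be needed afterwards
dx.groupBy(col("user"), col("dt"))
  .agg(sort_array(collect_list(struct(col("time_value"), col("item")))).as("events"))
  .show(false)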

Is there a better way?


Solution

  • I would indeed use window functions to isolate the change points; I don't think there are alternatives:

    import org.apache.spark.sql.expressions.Window
    import org.apache.spark.sql.functions.{coalesce, lag, lead, lit}

    val win = Window.partitionBy($"user", $"dt").orderBy($"time_value")

    dx
      .orderBy($"time_value")
      // true when the item differs from the previous row's item within the window
      .withColumn("item_change_post", coalesce(lag($"item", 1).over(win) =!= $"item", lit(false)))
      // true when the next row marks a change
      .withColumn("item_change_pre", lead($"item_change_post", 1).over(win))
      // keep only the rows directly before and after each change
      .where($"item_change_pre" or $"item_change_post")
      .show()
    
    +----+--------+-------------------+----+----------------+---------------+
    |user|      dt|         time_value|item|item_change_post|item_change_pre|
    +----+--------+-------------------+----+----------------+---------------+
    | id1|20200101|2020-01-01 09:00:00|   A|           false|           true|
    | id1|20200101|2020-01-01 10:00:00|   B|            true|          false|
    +----+--------+-------------------+----+----------------+---------------+
    

    then use something like groupBy($"user",$"dt").agg(collect_list(struct($"time_value",$"item")))
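    For completeness, a minimal sketch of that final step (the val names changes and result are mine; sort_array is used because collect_list gives no ordering guarantee after a shuffle):

    import org.apache.spark.sql.functions.{collect_list, sort_array, struct}

    // `changes` holds the filtered change-point rows from above
    val changes = dx
      .withColumn("item_change_post", coalesce(lag($"item", 1).over(win) =!= $"item", lit(false)))
      .withColumn("item_change_pre", lead($"item_change_post", 1).over(win))
      .where($"item_change_pre" or $"item_change_post")

    val result = changes
      .groupBy($"user", $"dt")
      // structs sort by their first field, so the pairs come out ordered by time_value
      .agg(sort_array(collect_list(struct($"time_value", $"item"))).as("item_changes"))

    result.show(false)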

    I don't think that multiple shuffles occur, because you always partition/group by the same keys.
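    If in doubt, you can check the physical plan of the result sketched above and count the Exchange operators (shuffle boundaries):

    // Exchange nodes in the plan mark shuffle boundaries
    result.explain()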

    You can try to make it more efficient by first aggregating your initial dataframe to the min/max time_value for each item, and then applying the same logic as above.
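    A rough sketch of that pre-aggregation (my interpretation; note it only matches the change-point output if each item occurs in a single contiguous run):

    import org.apache.spark.sql.functions.{max, min}

    // one row per (user, dt, item) instead of one row per event,
    // which shrinks the input that the window step has to sort
    val compact = dx
      .groupBy($"user", $"dt", $"item")
      .agg(min($"time_value").as("min_time"), max($"time_value").as("max_time"))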