Tags: apache-spark, pyspark, group-by, apache-spark-sql, rolling-computation

PySpark: aggregate mode (most frequent) value in a rolling window


I have a DataFrame like the one below. I would like to group by device and order by start_time within each group. Then, for each row in the group, I want the most frequently occurring station in a window of three rows: the current row and the two rows before it.

from pyspark.sql import Window

# Assumes an active SparkSession named `spark`.
columns = ['device', 'start_time', 'station']
data = [("Python", 1, "station_1"), ("Python", 2, "station_2"), ("Python", 3, "station_1"), ("Python", 4, "station_2"), ("Python", 5, "station_2"), ("Python", 6, None)]

test_df = spark.createDataFrame(data, columns)
rolling_w = Window.partitionBy('device').orderBy('start_time').rowsBetween(-2, 0)  # current row plus the two preceding rows

Desired output:

+------+----------+---------+--------------------+
|device|start_time|  station|rolling_mode_station|
+------+----------+---------+--------------------+
|Python|         1|station_1|           station_1|
|Python|         2|station_2|           station_2|
|Python|         3|station_1|           station_1|
|Python|         4|station_2|           station_2|
|Python|         5|station_2|           station_2|
|Python|         6|     null|           station_2|
+------+----------+---------+--------------------+

PySpark does not have a mode() function. I know how to get the most frequent value in a static groupBy, as shown here, but I don't know how to adapt that approach to a rolling window.


Solution

  • You can use the collect_list function over the defined window to gather the stations from the last 3 rows, then calculate the most frequent element of each resulting array. Note that collect_list skips nulls, which is why the last row (station = null) can still yield station_2.

    To get the most frequent element of the array, you can either explode it and then group by and count, as in the linked post you already saw, or use a UDF like this:

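    import pyspark.sql.functions as F
    
    # Most frequent element of the collected window. Ties are broken arbitrarily
    # (set order is not guaranteed); default=None covers an all-null window.
    mode_udf = F.udf(lambda x: max(set(x), key=x.count, default=None))
    
    test_df.withColumn(
        "rolling_mode_station",
        F.collect_list("station").over(rolling_w)  # nulls are skipped
    ).withColumn(
        "rolling_mode_station",
        mode_udf(F.col("rolling_mode_station"))
    ).show()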
    
    #+------+----------+---------+--------------------+
    #|device|start_time|  station|rolling_mode_station|
    #+------+----------+---------+--------------------+
    #|Python|         1|station_1|           station_1|
    #|Python|         2|station_2|           station_1|
    #|Python|         3|station_1|           station_1|
    #|Python|         4|station_2|           station_2|
    #|Python|         5|station_2|           station_2|
    #|Python|         6|     null|           station_2|
    #+------+----------+---------+--------------------+
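
One difference from the desired output is the row with start_time 2: its window holds station_1 and station_2 once each, and the question expects station_2 while the UDF output above shows station_1. The sketch below is one possible way to make that tie-break deterministic by preferring the value seen most recently in the window. The helper name mode_latest_on_tie and the "latest wins" rule are assumptions made for illustration; the question does not state a tie-breaking rule. The code is plain Python so the logic can be checked without a Spark session.

```python
from collections import Counter
from typing import Optional, Sequence


def mode_latest_on_tie(values: Optional[Sequence[Optional[str]]]) -> Optional[str]:
    """Return the most frequent non-null value; on a tie, prefer the one seen last.

    Hypothetical helper: the "latest wins" rule is an assumption, chosen because
    it matches the desired output shown in the question.
    """
    # Drop nulls defensively (collect_list already skips them in Spark).
    cleaned = [v for v in (values or []) if v is not None]
    if not cleaned:
        return None  # empty or all-null window

    counts = Counter(cleaned)
    # Index of the last occurrence of each value within the window.
    last_seen = {v: i for i, v in enumerate(cleaned)}
    # Rank by frequency first, then by recency to break ties.
    return max(counts, key=lambda v: (counts[v], last_seen[v]))


# Simulate rowsBetween(-2, 0) over the question's station column.
stations = ["station_1", "station_2", "station_1", "station_2", "station_2", None]
rolling = [
    mode_latest_on_tie(stations[max(0, i - 2): i + 1])
    for i in range(len(stations))
]
print(rolling)
```

To use it in the pipeline above, the function could be wrapped with F.udf(mode_latest_on_tie) and applied to the collect_list column in place of the lambda. Because the window is ordered by start_time, the collected list is expected to follow that order, so "seen last" corresponds to the most recent row.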