
Calculate rolling counts from two different time series columns in pyspark


I have a pyspark dataframe that contains two columns, arrival and departure. The idea is to calculate the number of departure events that fall within a window derived from the arrival time. For example, if an item arrived at 23:00, I would like to take a 12-hour look-back window [11:00, 23:00] and count the number of items that left within that interval.

Here is my code. But as you can see, it doesn't work, because the hour-based range aggregation can only be driven by one column: either arrival_time or dept_time, not both at once.

from pyspark.sql import SparkSession
from pyspark.sql import functions as f
from pyspark.sql.functions import col
from pyspark.sql.window import Window

spark = SparkSession.builder.appName("rolling_window_example").getOrCreate()

# Sample data
data = [
    ("2024-05-10 02:00:00", "2024-05-10 21:30:00", 0, 1),
    ("2024-05-12 14:10:00", "2024-05-13 02:00:00", 1, 1),
    ("2024-05-05 03:00:00", "2024-05-14 03:30:00", 2, 2),
    ("2024-05-14 01:32:00", "2024-05-14 23:30:00", 0, 2),
    ("2024-05-14 01:00:00", "2024-05-15 01:30:00", 0, 1)
]

columns = ["dept_time", "arrival_time", "ground_truth_12", "ground_truth_24"]

# Create DataFrame
df = spark.createDataFrame(data, columns)
df = df.withColumn("dept_timestamp", col("dept_time").cast("timestamp"))
df = df.withColumn("arrival_timestamp", col("arrival_time").cast("timestamp"))
df = df.withColumn("dept_time", col("dept_time").cast("timestamp").cast("long"))
df = df.withColumn("arrival_time", col("arrival_time").cast("timestamp").cast("long"))

# Calculate windows wrt arrival time
window_12 = Window.partitionBy().orderBy("arrival_time").rangeBetween(-12 * 3600, Window.currentRow)
window_24 = Window.partitionBy().orderBy("arrival_time").rangeBetween(-24 * 3600, Window.currentRow)

df_rolling = df \
    .orderBy("dept_time") \
    .withColumn("t_12_count", f.count("dept_time").over(window_12)) \
    .withColumn("t_24_count", f.count("dept_time").over(window_24))
                
# Show results
display(df_rolling)

The output is not as expected:

+----------+------------+---------------+---------------+-------------------+-------------------+----------+----------+
| dept_time|arrival_time|ground_truth_12|ground_truth_24|     dept_timestamp|  arrival_timestamp|t_12_count|t_24_count|
+----------+------------+---------------+---------------+-------------------+-------------------+----------+----------+
|1715306400|  1715376600|              0|              1|2024-05-10 02:00:00|2024-05-10 21:30:00|         1|         1|
|1715523000|  1715565600|              1|              1|2024-05-12 14:10:00|2024-05-13 02:00:00|         1|         1|
|1714878000|  1715657400|              2|              2|2024-05-05 03:00:00|2024-05-14 03:30:00|         1|         1|
|1715650320|  1715729400|              0|              2|2024-05-14 01:32:00|2024-05-14 23:30:00|         1|         2|
|1715648400|  1715736600|              0|              1|2024-05-14 01:00:00|2024-05-15 01:30:00|         2|         3|
+----------+------------+---------------+---------------+-------------------+-------------------+----------+----------+

The expected output can be seen in the columns: ground_truth_12 and ground_truth_24 for 12 hours and 24 hours window respectively.
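
As a sanity check, the expected counts can also be computed by brute force: for each arrival, compare it against every departure and count those falling in [arrival - 12h, arrival]. Below is a minimal sketch of that check, assuming the df built above (the arrivals, departures and brute_force_12 names are only for illustration); it is an O(n^2) join, so it is only suitable for a small sample like this one.

arrivals = df.select(col("arrival_time").alias("arr_s"), col("arrival_timestamp"))
departures = df.select(col("dept_time").alias("dep_s"))

# For each arrival, count departures between 12 hours before the arrival and
# the arrival itself (both long columns hold epoch seconds, cast above).
brute_12 = (
    arrivals
    .join(departures,
          (col("arr_s") - col("dep_s") >= 0) & (col("arr_s") - col("dep_s") <= 12 * 3600),
          "left")
    .groupBy("arrival_timestamp")
    .agg(f.count("dep_s").alias("brute_force_12"))
)
brute_12.orderBy("arrival_timestamp").show(truncate=False)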


Solution

  • Here's a clever way to figure out all departures before the current row's arrival:

    Label each time with a flag: "A" for arrival or "D" for departure.

    Now union these two dataframes.

    Order the unioned dataframe by time, irrespective of label.

    Create a window specification that counts all the "D" rows occurring within the range of -12 hours to 0 seconds (the current row's time).

    Do this only for the "A" rows, since those are what we care about in the final result.

    Do the same for the -24 hours window.

    Important note:

    There is an error in the second row's ground truth, which I have corrected in the sample data below.

    Following is a working example.

    from pyspark.sql import SparkSession
    from pyspark.sql.functions import *
    from pyspark.sql.window import Window
    from pyspark.sql.types import *
    
    spark = SparkSession.builder.appName("SelfJoinExample").getOrCreate()
    
    ## There is an error in the second row's ground truth, which I have corrected below.
    
    data = [
        ("2024-05-10 02:00:00", "2024-05-10 21:30:00", 0, 1),
        ("2024-05-12 12:00:00", "2024-05-13 02:00:00", 0, 1),
        ("2024-05-05 03:00:00", "2024-05-14 03:30:00", 2, 2),
        ("2024-05-14 01:32:00", "2024-05-14 23:30:00", 0, 2),
        ("2024-05-14 01:00:00", "2024-05-15 01:30:00", 0, 1)
    ]
    
    columns = ["departure_time", "arrival_time", "ground_truth_12", "ground_truth_24"]
    
    # Create DataFrame
    df = spark.createDataFrame(data, columns)
    df = df.withColumn("departure_timestamp", col("departure_time").cast("timestamp")).drop("departure_time")
    df = df.withColumn("arrival_timestamp", col("arrival_time").cast("timestamp")).drop("arrival_time")
    df = df.withColumn("mono_id", monotonically_increasing_id())
    df = df.withColumn("arrival_label", array(col("arrival_timestamp"), lit("A"), col("mono_id")))
    df = df.withColumn("dept_label", array(col("departure_timestamp"), lit("D"), col("mono_id")))
    
    df_arrival = df.select(col("arrival_label").alias("common_name"))
    df_dept = df.select(col("dept_label").alias("common_name"))
    
    df_union = df_arrival.union(df_dept)
    
    df_union = df_union.orderBy(col("common_name")[0])
    df_union.show(n=1000, truncate=False)
    
    
    df_mid = df_union.withColumn("timestamp", to_timestamp(df_union["common_name"][0]))
    df_mid = df_mid.withColumn("long_ts", col("timestamp").cast("long"))
    df_mid = df_mid.withColumn("type", df_union["common_name"][1])
    df_mid = df_mid.withColumn("value", df_union["common_name"][2].cast(LongType()))
    df_mid = df_mid.drop("common_name")
    df_mid = df_mid.withColumn("dept_flag", when(col("type") == "D", 1).otherwise(0))
    
    df_mid.show(truncate=False)
    
    
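    # Trailing range windows over the combined event stream; rangeBetween bounds
    # are inclusive, so a departure exactly 12 (or 24) hours before an arrival counts.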
    windowSpec12 = Window.orderBy("long_ts").rangeBetween(-12 * 3600, 0)
    windowSpec24 = Window.orderBy("long_ts").rangeBetween(-24 * 3600, 0)
    
    
    
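    # Sum the departure flags inside each window, but only materialise the count
    # on arrival ("A") rows; departure rows are left as NULL.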
    df_int = df_mid.withColumn("calc_t12",  when(col("type") == "A", sum("dept_flag").over(windowSpec12)).otherwise(None))
    df_int = df_int.withColumn("calc_t24",  when(col("type") == "A", sum("dept_flag").over(windowSpec24)).otherwise(None))
    
    df_int.show(n=1000, truncate=False)
    
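    # Join the counts back onto the original rows via mono_id (carried through as
    # "value"), keeping only the arrival rows, and compare with the ground truth.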
    df_crosscheck = df.join(df_int, on=[col("type") == "A", col("mono_id") == col("value")], how="inner")
    
    print("Final Result")
    df_crosscheck.select("ground_truth_12", "ground_truth_24", "calc_t12", "calc_t24").show(n=1000, truncate=False)
    

    Final cross-check dataframe:

    +---------------+---------------+--------+--------+
    |ground_truth_12|ground_truth_24|calc_t12|calc_t24|
    +---------------+---------------+--------+--------+
    |0              |1              |0       |1       |
    |0              |1              |0       |1       |
    |2              |2              |2       |2       |
    |0              |2              |0       |2       |
    |0              |1              |0       |1       |
    +---------------+---------------+--------+--------+
    

    Full output below:

    +--------------------------------------+
    |common_name                           |
    +--------------------------------------+
    |[2024-05-05 03:00:00, D, 94489280512] |
    |[2024-05-10 02:00:00, D, 25769803776] |
    |[2024-05-10 21:30:00, A, 25769803776] |
    |[2024-05-12 12:00:00, D, 60129542144] |
    |[2024-05-13 02:00:00, A, 60129542144] |
    |[2024-05-14 01:00:00, D, 163208757248]|
    |[2024-05-14 01:32:00, D, 128849018880]|
    |[2024-05-14 03:30:00, A, 94489280512] |
    |[2024-05-14 23:30:00, A, 128849018880]|
    |[2024-05-15 01:30:00, A, 163208757248]|
    +--------------------------------------+
    
    +-------------------+----------+----+------------+---------+
    |timestamp          |long_ts   |type|value       |dept_flag|
    +-------------------+----------+----+------------+---------+
    |2024-05-05 03:00:00|1714858200|D   |94489280512 |1        |
    |2024-05-10 02:00:00|1715286600|D   |25769803776 |1        |
    |2024-05-10 21:30:00|1715356800|A   |25769803776 |0        |
    |2024-05-12 12:00:00|1715495400|D   |60129542144 |1        |
    |2024-05-13 02:00:00|1715545800|A   |60129542144 |0        |
    |2024-05-14 01:00:00|1715628600|D   |163208757248|1        |
    |2024-05-14 01:32:00|1715630520|D   |128849018880|1        |
    |2024-05-14 03:30:00|1715637600|A   |94489280512 |0        |
    |2024-05-14 23:30:00|1715709600|A   |128849018880|0        |
    |2024-05-15 01:30:00|1715716800|A   |163208757248|0        |
    +-------------------+----------+----+------------+---------+
    
    +-------------------+----------+----+------------+---------+--------+--------+
    |timestamp          |long_ts   |type|value       |dept_flag|calc_t12|calc_t24|
    +-------------------+----------+----+------------+---------+--------+--------+
    |2024-05-05 03:00:00|1714858200|D   |94489280512 |1        |NULL    |NULL    |
    |2024-05-10 02:00:00|1715286600|D   |25769803776 |1        |NULL    |NULL    |
    |2024-05-10 21:30:00|1715356800|A   |25769803776 |0        |0       |1       |
    |2024-05-12 12:00:00|1715495400|D   |60129542144 |1        |NULL    |NULL    |
    |2024-05-13 02:00:00|1715545800|A   |60129542144 |0        |0       |1       |
    |2024-05-14 01:00:00|1715628600|D   |163208757248|1        |NULL    |NULL    |
    |2024-05-14 01:32:00|1715630520|D   |128849018880|1        |NULL    |NULL    |
    |2024-05-14 03:30:00|1715637600|A   |94489280512 |0        |2       |2       |
    |2024-05-14 23:30:00|1715709600|A   |128849018880|0        |0       |2       |
    |2024-05-15 01:30:00|1715716800|A   |163208757248|0        |0       |1       |
    +-------------------+----------+----+------------+---------+--------+--------+
    
    +---------------+---------------+--------+--------+
    |ground_truth_12|ground_truth_24|calc_t12|calc_t24|
    +---------------+---------------+--------+--------+
    |0              |1              |0       |1       |
    |0              |1              |0       |1       |
    |2              |2              |2       |2       |
    |0              |2              |0       |2       |
    |0              |1              |0       |1       |
    +---------------+---------------+--------+--------+
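
    One caveat worth noting: both the original attempt and this solution use Window.orderBy(...) without a partitionBy, so Spark moves all rows into a single partition to evaluate the window and logs a warning about it. That is fine for a sample of this size, but on large data you would want a partitioning key or a join-based formulation instead.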