python · apache-spark · pyspark · apache-spark-sql · window-functions

Pyspark window function to calculate number of transits between stops


I am using Pyspark and I would like to create a function which performs the following operation:

Given data describing the transactions of train users:

+----+-------------------+-------+---------------+-------------+-------------+
|USER|       DATE        |LINE_ID|      STOP     | TOPOLOGY_ID |TRANSPORT_ID |
+----+-------------------+-------+---------------+-------------+-------------+
|John|2021-01-27 07:27:34|      7| King Cross    |       171235|       03    |
|John|2021-01-27 07:28:00|     40| White Chapell |       123582|       03    |  
|John|2021-01-27 07:35:30|      4| Reaven        |       171565|       03    |  
|Tom |2021-01-27 07:27:23|      7| King Cross    |       171235|       03    |    
|Tom |2021-01-27 07:28:30|     40| White Chapell |       123582|       03    |                   
+----+-------------------+-------+---------------+-------------+-------------+

I would like to get the number of times each combination of stops A-B, B-C, etc. has been made, grouped into 30-minute intervals.

So, let's say user "John" goes from stop "King Cross" to "White Chapell" at 7:27 and then goes from "White Chapell" to "Reaven" at 7:35.
Meanwhile, "Tom" goes from "King Cross" to "White Chapell" at 7:27 and then from "White Chapell" to "Oxford Circus" at 7:32.

The result of the operation would have to be something like:

+----------------------+-----------------+---------------+-----------+
|          DATE        |   ORIG_STOP     |   DEST_STOP   | NUM_TRANS |
+----------------------+-----------------+---------------+-----------+
|   2021-01-27 07:00:00|  King Cross     | White Chapell |       2   |
|   2021-01-27 07:30:00|  White Chapell  | Reaven        |       1   |              
+----------------------+-----------------+---------------+-----------+

I have tried using window functions, but I can't manage to get what I really want.


Solution

  • You may try running the following

    Using Spark SQL

    The first CTE, initial_stop_groups, determines the related origin and destination stops and times with the LEAD window function. The next CTE, stop_groups, determines the associated 30-minute intervals using CASE expressions and date functions, and filters out rows without a destination stop (the last stop of each user's journey). The final projection then groups by the time interval, origin stop and destination stop to count NUM_TRANS, keeping only transits whose origin and destination intervals are at most 30 minutes apart.
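
    For John's three rows, for example, the first CTE should produce something roughly like this (the last pair is later dropped because it has no destination):

    +-------------------+-------------------+--------------+--------------+
    |     ORIG_DATE     |     STOP_DATE     |   ORIG_STOP  |   DEST_STOP  |
    +-------------------+-------------------+--------------+--------------+
    |2021-01-27 07:27:34|2021-01-27 07:28:00| King Cross   | White Chapell|
    |2021-01-27 07:28:00|2021-01-27 07:35:30| White Chapell| Reaven       |
    |2021-01-27 07:35:30|               null| Reaven       | null         |
    +-------------------+-------------------+--------------+--------------+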

    Assuming your data is in input_df

    input_df.createOrReplaceTempView("input_df")
    
    output_df = sparkSession.sql("""
     WITH initial_stop_groups AS (
            SELECT
                DATE as ORIG_DATE,
                LEAD(DATE) OVER (
                    PARTITION BY USER,TRANSPORT_ID
                    ORDER BY DATE
                ) as STOP_DATE,
                STOP as ORIG_STOP,
                LEAD(STOP) OVER (
                    PARTITION BY USER,TRANSPORT_ID
                    ORDER BY DATE
                ) as DEST_STOP
            FROM
                input_df
        ),
        stop_groups AS (
            SELECT 
                CAST(CONCAT(
                  CAST(ORIG_DATE as DATE),
                  ' ',
                  hour(ORIG_DATE),
                  ':',
                  CASE WHEN minute(ORIG_DATE) < 30 THEN '00' ELSE '30' END,
                  ':00'
                ) AS TIMESTAMP) as ORIG_TIME,
                CASE WHEN STOP_DATE IS NOT NULL THEN CAST(CONCAT(
                  CAST(STOP_DATE as DATE),
                  ' ',
                  hour(STOP_DATE),
                  ':',
                  CASE WHEN minute(STOP_DATE) < 30 THEN '00' ELSE '30' END,
                  ':00'
                ) AS TIMESTAMP) ELSE NULL END as STOP_TIME,
                ORIG_STOP,
                DEST_STOP
            FROM 
                initial_stop_groups
            WHERE
                DEST_STOP IS NOT NULL
        )
        SELECT
            STOP_TIME as DATE, 
            ORIG_STOP,
            DEST_STOP,
            COUNT(1) as NUM_TRANS
        FROM
            stop_groups
        WHERE
            (unix_timestamp(STOP_TIME) - unix_timestamp(ORIG_TIME)) <=30*60
            
        GROUP BY
            STOP_TIME, ORIG_STOP, DEST_STOP;
        
    """)
    
    output_df.show()
    
    +------------------------+--------------+--------------+---------+
    |          DATE          |  orig_stop   |  dest_stop   |num_trans|
    +------------------------+--------------+--------------+---------+
    |2021-01-27T07:00:00.000Z| King Cross   | White Chapell|        2|
    |2021-01-27T07:30:00.000Z| White Chapell| Reaven       |        1|
    +------------------------+--------------+--------------+---------+

    View on DB Fiddle

    • CAST((STOP_TIME - ORIG_TIME) as STRING) IN ('0 seconds','30 minutes') was replaced by (unix_timestamp(STOP_TIME) - unix_timestamp(ORIG_TIME)) <=30*60
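
    As a side note, the 30-minute flooring built with CONCAT/CASE above could presumably also be expressed with plain epoch arithmetic (1800 seconds = 30 minutes). A minimal, untested sketch against the same input_df view:

    sparkSession.sql("""
        SELECT
            -- floor DATE to the start of its 30-minute slot
            CAST(from_unixtime(floor(unix_timestamp(DATE) / 1800) * 1800) AS TIMESTAMP) AS SLOT,
            USER,
            STOP
        FROM input_df
    """).show()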

    Using the Spark API

    Actual code

    from pyspark.sql import functions as F
    from pyspark.sql import Window
    
    # LEAD() over this window pairs each stop with the same user's next stop on the same transport
    next_stop_window = Window().partitionBy("USER","TRANSPORT_ID").orderBy("DATE")
    
    output_df = (
        input_df.select(
            F.col("DATE").alias("ORIG_DATE"),
            F.lead("DATE").over(next_stop_window).alias("STOP_DATE"),
            F.col("STOP").alias("ORIG_STOP"),
            F.lead("STOP").over(next_stop_window).alias("DEST_STOP"),
        ).where(
            F.col("DEST_STOP").isNotNull()
    ).select(
        # floor ORIG_DATE / STOP_DATE to the containing 30-minute slot (hh:00:00 or hh:30:00)
            F.concat(
                F.col("ORIG_DATE").cast("DATE"),
                F.lit(' '),
                F.hour("ORIG_DATE"),
                F.lit(':'),
                F.when(
                    F.minute("ORIG_DATE") < 30, '00'
                ).otherwise('30'),
                F.lit(':00')
            ).cast("TIMESTAMP").alias("ORIG_TIME"),
            F.concat(
                F.col("STOP_DATE").cast("DATE"),
                F.lit(' '),
                F.hour("STOP_DATE"),
                F.lit(':'),
                F.when(
                    F.minute("STOP_DATE") < 30, '00'
                ).otherwise('30'),
                F.lit(':00')
            ).cast("TIMESTAMP").alias("STOP_TIME"),
            F.col("ORIG_STOP"),
            F.col("DEST_STOP")
    ).where(
        # keep only transits whose origin and destination slots are at most 30 minutes apart
            (F.unix_timestamp("STOP_TIME") - F.unix_timestamp("ORIG_TIME")) <= 30*60
            # (F.col("STOP_TIME")-F.col("ORIG_TIME")).cast("STRING").isin(['0 seconds','30 minutes'])
        ).groupBy(
            F.col("STOP_TIME"),
            F.col("ORIG_STOP"),
            F.col("DEST_STOP"),
        ).count().select(
            F.col("STOP_TIME").alias("DATE"),
            F.col("ORIG_STOP"),
            F.col("DEST_STOP"),
            F.col("count").alias("NUM_TRANS"),
        )
        
    )
    output_df.show()
    
    
    +------------------------+--------------+--------------+---------+
    |          DATE          |  orig_stop   |  dest_stop   |num_trans|
    +------------------------+--------------+--------------+---------+
    |2021-01-27T07:00:00.000Z| King Cross   | White Chapell|        2|
    |2021-01-27T07:30:00.000Z| White Chapell| Reaven       |        1|
    +------------------------+--------------+--------------+---------+

    Resulting Schema

    output_df.printSchema()
    
    root
     |-- DATE: timestamp (nullable = true)
     |-- ORIG_STOP: string (nullable = true)
     |-- DEST_STOP: string (nullable = true)
     |-- NUM_TRANS: long (nullable = false)
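
    A possible variant, as an untested sketch: the 30-minute bucketing can also lean on the built-in F.window tumbling-window helper instead of building the slot string by hand. Note that this filters on the raw event times rather than the bucketed ones, so rows near a slot boundary may land slightly differently than in the query above:

    from pyspark.sql import functions as F
    from pyspark.sql import Window

    next_stop_window = Window.partitionBy("USER", "TRANSPORT_ID").orderBy("DATE")

    alt_df = (
        input_df.select(
            F.col("DATE").alias("ORIG_DATE"),
            F.lead("DATE").over(next_stop_window).alias("STOP_DATE"),
            F.col("STOP").alias("ORIG_STOP"),
            F.lead("STOP").over(next_stop_window).alias("DEST_STOP"),
        ).where(
            F.col("DEST_STOP").isNotNull()
        ).where(
            # transit completed within 30 minutes of leaving the origin stop
            (F.unix_timestamp("STOP_DATE") - F.unix_timestamp("ORIG_DATE")) <= 30 * 60
        ).groupBy(
            # 30-minute tumbling window that contains the arrival time
            F.window("STOP_DATE", "30 minutes").alias("SLOT"),
            "ORIG_STOP",
            "DEST_STOP",
        ).count().select(
            F.col("SLOT.start").alias("DATE"),
            "ORIG_STOP",
            "DEST_STOP",
            F.col("count").alias("NUM_TRANS"),
        )
    )
    alt_df.show(truncate=False)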
    

    Setup code for reproducibility

    data="""+----+-------------------+-------+---------------+-------------+-------------+
    |USER|       DATE        |LINE_ID|      STOP     | TOPOLOGY_ID |TRANSPORT_ID |
    +----+-------------------+-------+---------------+-------------+-------------+
    |John|2021-01-27 07:27:34|      7| King Cross    |       171235|       03    |
    |John|2021-01-27 07:28:00|     40| White Chapell |       123582|       03    |  
    |John|2021-01-27 07:35:30|      4| Reaven        |       171565|       03    |  
    |Tom |2021-01-27 07:27:23|      7| King Cross    |       171235|       03    |    
    |Tom |2021-01-27 07:28:30|     40| White Chapell |       123582|       03    |                   
    +----+-------------------+-------+---------------+-------------+-------------+
    """
    
    # Parse the ASCII table: line 1 holds the column names, lines 3..-2 hold the data rows;
    # each cell is the text between the "|" separators.
    rows = [[pc.strip() for pc in line.strip().split("|")[1:-1]] for line in data.strip().split("\n")[3:-1]]
    headers = [pc.strip() for pc in data.strip().split("\n")[1].split("|")[1:-1]]
    
    from pyspark.sql import functions as F
    
    # All columns come in as strings, so cast DATE to a proper timestamp
    input_df = sparkSession.createDataFrame(rows, schema=headers)
    input_df = input_df.withColumn("DATE", F.col("DATE").cast("TIMESTAMP"))
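
    Optionally, a quick sanity check that the reproduced input looks right:

    # DATE should now be a timestamp; the remaining columns stay as strings
    input_df.printSchema()
    input_df.show(truncate=False)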
    
    
    

    Let me know if this works for you.