How to aggregate in PySpark based on values in consecutive rows


I have an input dataframe with 3 columns: Time, Name, and Flag. I would like to aggregate it into Start and End columns over consecutive rows where Name and Flag have the same value.

Input data frame

Time Name Flag
5/1/2023 1:01 Peter 1
5/1/2023 1:02 Peter 1
5/1/2023 1:03 Peter 1
5/1/2023 1:04 Peter 0
5/1/2023 1:05 Peter 0
5/1/2023 1:06 Peter 1
5/1/2023 1:07 Peter 1
5/1/2023 1:08 Peter 1
5/1/2023 1:01 John 1
5/1/2023 1:02 John 0
5/1/2023 1:03 John 0
5/1/2023 1:04 John 0
5/1/2023 1:05 John 0
5/1/2023 1:06 John 0
5/1/2023 1:07 John 1
5/1/2023 1:08 John 1
5/2/2023 1:10 Peter 1
5/2/2023 1:11 Peter 1
5/2/2023 1:20 John 0
5/2/2023 1:21 John 0
5/2/2023 1:22 John 0

Output data frame

Start End Name Flag
5/1/2023 1:01 5/1/2023 1:03 Peter 1
5/1/2023 1:04 5/1/2023 1:05 Peter 0
5/1/2023 1:06 5/1/2023 1:08 Peter 1
5/2/2023 1:10 5/2/2023 1:11 Peter 1
5/1/2023 1:01 5/1/2023 1:01 John 1
5/1/2023 1:02 5/1/2023 1:06 John 0
5/1/2023 1:07 5/1/2023 1:08 John 1
5/2/2023 1:20 5/2/2023 1:22 John 0

In this case, "consecutive rows" means consecutive in time.

1:08 and 1:10 are not combined because there is a gap (1:09 is missing) between those two rows.

Can you please tell me how I can do that?


Solution
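
For reference, a minimal sketch that recreates the sample input so the snippets below can be run as-is (the SparkSession setup is an assumption; the column names and values come from the question):

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.appName('consecutive-agg').getOrCreate()

    data = [
        ('5/1/2023 1:01', 'Peter', 1), ('5/1/2023 1:02', 'Peter', 1), ('5/1/2023 1:03', 'Peter', 1),
        ('5/1/2023 1:04', 'Peter', 0), ('5/1/2023 1:05', 'Peter', 0), ('5/1/2023 1:06', 'Peter', 1),
        ('5/1/2023 1:07', 'Peter', 1), ('5/1/2023 1:08', 'Peter', 1),
        ('5/1/2023 1:01', 'John', 1), ('5/1/2023 1:02', 'John', 0), ('5/1/2023 1:03', 'John', 0),
        ('5/1/2023 1:04', 'John', 0), ('5/1/2023 1:05', 'John', 0), ('5/1/2023 1:06', 'John', 0),
        ('5/1/2023 1:07', 'John', 1), ('5/1/2023 1:08', 'John', 1),
        ('5/2/2023 1:10', 'Peter', 1), ('5/2/2023 1:11', 'Peter', 1),
        ('5/2/2023 1:20', 'John', 0), ('5/2/2023 1:21', 'John', 0), ('5/2/2023 1:22', 'John', 0),
    ]
    df = spark.createDataFrame(data, ['Time', 'Name', 'Flag'])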

  • First, you want to create groupings that meet your condition. A general tip is to create a flag that is 1 when a new group should start and 0 when the row should be combined with the previous one. A cumulative sum over this flag then produces the groupings you want.
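
    For example, for Peter on 5/1/2023 the Flag sequence is 1, 1, 1, 0, 0, 1, 1, 1. Here is a tiny plain-Python sketch of the flag-and-cumsum idea (ignoring the time-gap condition for a moment):

    from itertools import accumulate

    flags = [1, 1, 1, 0, 0, 1, 1, 1]
    # 1 where a new group starts: the first row, or where Flag differs from the previous row
    boundary = [1 if i == 0 or flags[i] != flags[i - 1] else 0 for i in range(len(flags))]
    # cumulative sum of the boundaries yields the group ids
    grp = list(accumulate(boundary))
    print(boundary)  # [1, 0, 0, 1, 0, 1, 0, 0]
    print(grp)       # [1, 1, 1, 2, 2, 3, 3, 3]

    These three group ids correspond exactly to Peter's three 5/1/2023 rows in the expected output.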

    Your conditions are:

    from pyspark.sql import functions as F
    from pyspark.sql import Window

    # convert Time to timestamp
    df = df.withColumn('timestamp', F.to_timestamp('Time', 'M/d/yyyy H:mm'))

    w = Window.partitionBy('Name').orderBy('timestamp')

    # the previous Flag is different from the current Flag
    (F.lag('Flag').over(w) != F.col('Flag'))

    # OR the previous timestamp is more than 1 minute ago
    | (((F.col('timestamp').cast('long') - F.lag('timestamp').over(w).cast('long')) / 60) > 1)
    

    With these conditions, create the groupings as a grp column and use that column to aggregate.

    w = Window.partitionBy('Name').orderBy('timestamp')
    df = (df.withColumn('timestamp', F.to_timestamp('Time', 'M/d/yyyy H:mm'))
          # 1 when a new group starts (first row, Flag changed, or a gap of more than 1 minute), else 0
          .withColumn('grp', (F.lag('Flag').over(w).isNull()
                              | (F.lag('Flag').over(w) != F.col('Flag'))
                              | (((F.col('timestamp').cast('long') - F.lag('timestamp').over(w).cast('long')) / 60) > 1)).cast('int'))
          # running sum of the boundary flag gives a group id per Name
          .withColumn('grp', F.sum('grp').over(w))
          .groupby('Name', 'grp')
          .agg(F.min('Time').alias('Start'), F.max('Time').alias('End'), F.first('Flag').alias('Flag')))
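
    A quick way to inspect the result in the same shape as the expected output (the ordering below is only for readability, not part of the aggregation):

    (df.select('Start', 'End', 'Name', 'Flag')
       .orderBy('Name', F.to_timestamp('Start', 'M/d/yyyy H:mm'))
       .show(truncate=False))

    One note on the aggregation: Start and End are taken as string min/max of Time, which works for this sample; if a group could cross an hour boundary (e.g. 9:59 → 10:00), it is safer to take F.min/F.max on the timestamp column and format it back with F.date_format.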