python, apache-spark, pyspark, spark-structured-streaming

Is there a way to do a custom window that is not time-based on a Kafka stream, using PySpark?


I have a Kafka stream sending me heartbeat data for a cyclist on a circuit. I need to be able to compute the average heartbeat for each lap he did. I tried to use the session window, but it only works on time, and in my case the lap time can be different for each lap.

I found that with foreachBatch I can create a Window on any column:

.foreachBatch(calculate_heartbeat) 

and in this function:

from pyspark.sql.window import Window
from pyspark.sql.functions import avg

def calculate_heartbeat(df, batch_id):
    # Non-time-based window: partition by the lap column
    lap_window = Window.partitionBy("lap")
    df = df.withColumn("avg", avg("heartbeat").over(lap_window))
    df.show(truncate=False)
    df.groupBy("lap").agg(avg("heartbeat")).show()

    return df

but when using foreachBatch I am not able to accumulate the whole data of a lap across batches. Is there a way to do it?

I tried a different approach: creating an empty DataFrame and appending each batch I receive to it, expecting to accumulate the whole lap in that DataFrame (a sketch of this idea is shown below). Is there a better approach for this calculation, or a way to window by lap?
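
A minimal sketch of that accumulate-and-recompute idea (the global accumulated DataFrame and the avg_heartbeat alias are names made up for illustration; lap and heartbeat are the real columns):

from pyspark.sql.functions import avg

accumulated = None  # all rows received so far, across micro-batches

def calculate_heartbeat(df, batch_id):
    global accumulated
    # Union the new micro-batch into the running DataFrame
    accumulated = df if accumulated is None else accumulated.unionByName(df)
    # Recompute the per-lap average over everything seen so far
    accumulated.groupBy("lap").agg(avg("heartbeat").alias("avg_heartbeat")).show()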


Solution

  • I'm using Spark 3.4.0 for this test.

    I have generated some CSVs in the /content/input directory with the following content, which I understand matches what will be present in your event stream:

    lapId,heartbeat,timestamp
    1,122,2023-05-23 10:01:00
    1,132,2023-05-23 10:02:00
    2,137,2023-05-23 10:03:00
    2,122,2023-05-23 10:04:00
    2,132,2023-05-23 10:05:00
    3,137,2023-05-23 10:06:00
    3,122,2023-05-23 10:07:00
    3,132,2023-05-23 10:08:00
    4,137,2023-05-23 10:09:00
    

    Use the session window aggregation functionality of Spark Structured Streaming, where in this case it is assumed that at least one event will be reported every 5 minutes during a lap (otherwise the lap is considered over):

    from pyspark.sql import SparkSession
    from pyspark.sql.types import StructType, StructField, TimestampType, StringType, LongType
    from pyspark.sql.functions import session_window, avg
    
    spark = SparkSession.builder.master("local[*]").getOrCreate()
    
    schema = StructType([
      StructField('lapId', StringType(), True),
      StructField('heartbeat', LongType(), True),
      StructField('timestamp', TimestampType(), True)
    ])
    
    df = spark.readStream.format("csv").schema(schema).option("header", True).load("/content/input")
    
    # This is the part that interests you
    avg_heartbeat_rate_per_lap = df \
        .withWatermark("timestamp", "10 minutes") \
        .groupBy(
            session_window(df.timestamp, "5 minutes"),
            df.lapId) \
        .agg(avg("heartbeat"))
    
    query = avg_heartbeat_rate_per_lap \
        .writeStream \
        .outputMode("complete") \
        .queryName("aggregates") \
        .format("memory") \
        .start()
    
    # Wait for all available input to be processed before querying the in-memory sink
    query.processAllAvailable()

    spark.sql("select * from aggregates").show(truncate=False)
    

    The results are correct as per the inputs:

    +------------------------------------------+-----+------------------+
    |session_window                            |lapId|avg(heartbeat)    |
    +------------------------------------------+-----+------------------+
    |{2023-05-23 10:06:00, 2023-05-23 10:13:00}|3    |130.33333333333334|
    |{2023-05-23 10:01:00, 2023-05-23 10:07:00}|1    |127.0             |
    |{2023-05-23 10:09:00, 2023-05-23 10:14:00}|4    |137.0             |
    |{2023-05-23 10:03:00, 2023-05-23 10:10:00}|2    |130.33333333333334|
    +------------------------------------------+-----+------------------+
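
    Since your source is a Kafka topic rather than CSV files, the same aggregation should plug onto a Kafka reader. The following is only a sketch: the broker address (localhost:9092), the topic name (heartbeats), and the assumption that each message value is a JSON document with lapId, heartbeat and timestamp fields are placeholders to adapt to your setup; the writeStream part stays the same as above.

    from pyspark.sql.functions import from_json, col

    # Requires the spark-sql-kafka-0-10 package on the classpath
    raw = (spark.readStream
        .format("kafka")
        .option("kafka.bootstrap.servers", "localhost:9092")  # placeholder broker
        .option("subscribe", "heartbeats")                     # placeholder topic
        .load())

    # Assume each Kafka value is a JSON document matching the schema defined above
    events = (raw
        .selectExpr("CAST(value AS STRING) AS json")
        .select(from_json(col("json"), schema).alias("data"))
        .select("data.*"))

    avg_heartbeat_rate_per_lap = (events
        .withWatermark("timestamp", "10 minutes")
        .groupBy(session_window(events.timestamp, "5 minutes"), events.lapId)
        .agg(avg("heartbeat")))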