Tags: dataframe, apache-spark, pyspark, databricks, spark-structured-streaming

Running a function for every batch of a PySpark structured stream without UDF


I have the following PySpark code. I am reading from Event Hubs, transforming the data with several functions (dataframe transformations), and writing the dataframe to a directory. The update_session_id function has to run for each batch, but it does not operate on the data coming from Event Hubs; it only has to update a lookup table, which is referenced in the transform_raw_data function, whenever the current timestamp is more than 2 hours past the timestamp maintained in the lookup table.

How can I implement this? Currently, the update_session_id function executes only once and then never runs again for the lifetime of the stream.
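For context, here is a hypothetical sketch of what update_session_id might look like; the lookup-table columns (session_id, updated_at), the Delta format, and the use of saveAsTable are assumptions and not part of the original code:

from datetime import datetime, timedelta
import uuid

from pyspark.sql import functions as F

def update_session_id(session_length, db_table):
    # Assumed schema: db_table holds (session_id, updated_at)
    lookup = spark.read.table(db_table)
    row = lookup.orderBy(F.col("updated_at").desc()).first()
    cutoff = datetime.utcnow() - timedelta(hours=session_length)
    if row is None or row["updated_at"] < cutoff:
        # Session is older than session_length hours, so rotate it
        new_row = [(str(uuid.uuid4()), datetime.utcnow())]
        spark.createDataFrame(new_row, ["session_id", "updated_at"]) \
             .write.format("delta").mode("overwrite").saveAsTable(db_table)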

df = spark.readStream.format("eventhubs").options(**conf).load() #Reading from eventhub

update_session_id(session_length, db_table) #function to update session value. Has to run for each batch or every hour

df = transform_raw_data(df, db_table) #transforming the raw data

df = filter_countries(df=df, country_list=COUNTRY_CODE_ACCEPTLIST)

df = map_vehicle_type(df)

df = df_to_json(df, output_column=DATA_COLUMN)

df.writeStream  \
     .format("delta")  \
     .outputMode("append")  \
     .partitionBy("YYYYMMDD","hour") \
     .option("checkpointLocation", "BASE_PATH_RAW/CHECKPOINT_REPORTING_RAW_LOCATION")  \
     .start("BASE_PATH_RAW/REPORTING_RAW_LOCATION")

Solution

  • You can achieve this with the foreachBatch function, which is executed for every micro-batch. In your case it could look like the following:

    def my_foreach_batch(df, epoch_id):
      update_session_id(session_length, db_table)
      df.write.format("delta").mode("append") \
        .partitionBy("YYYYMMDD","hour") \
        .save("BASE_PATH_RAW/REPORTING_RAW_LOCATION")
    
    df.writeStream  \
      .outputMode("append")  \
      .foreachBatch(my_foreach_batch) \
      .option("checkpointLocation", "BASE_PATH_RAW/CHECKPOINT_REPORTING_RAW_LOCATION") \
      .start()
    

    Please note that by default foreachBatch isn't idempotent: it can be called several times for the same micro-batch, for example when the stream is restarted, and depending on the complexity of your operations this may lead to duplicate appends. On Databricks Runtime 8.4 and above you can guard against that by using idempotent writes into Delta tables, as sketched below.
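
    A minimal sketch of such an idempotent write inside foreachBatch, using the Delta writer options txnAppId and txnVersion; the application id string here is an assumption (any stable identifier for this stream will do):

    APP_ID = "reporting_raw_stream"  # assumed name; any stable string identifying this stream

    def my_idempotent_foreach_batch(df, epoch_id):
      update_session_id(session_length, db_table)
      # txnAppId + txnVersion let Delta detect and skip a micro-batch that was already written
      df.write.format("delta").mode("append") \
        .partitionBy("YYYYMMDD","hour") \
        .option("txnAppId", APP_ID) \
        .option("txnVersion", epoch_id) \
        .save("BASE_PATH_RAW/REPORTING_RAW_LOCATION")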