
Stream-Static Join: How to refresh (unpersist/persist) a static DataFrame periodically


I am building a Spark Structured Streaming application that performs a stream-static (batch-stream) join. The source of the batch data is updated periodically.

So I plan to unpersist and re-persist that batch data periodically.

Below is the sample code I am using to persist and unpersist the batch data.

Flow:

  • Read the batch data.
  • Persist the batch data.
  • Every hour, unpersist the data, read the batch data again, and persist it again.

However, the batch data is not being refreshed every hour.

Code:

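import java.time.{Duration, Instant}
import org.apache.spark.storage.StorageLevel

// `handler.readBatchDF` and `refreshTime` (in seconds) are defined elsewhere in the application
var batchDF = handler.readBatchDF(sparkSession)
batchDF.persist(StorageLevel.MEMORY_AND_DISK)
var refreshedTime: Instant = Instant.now()

// Re-read and re-persist the batch data once more than `refreshTime` seconds have passed
if (Duration.between(refreshedTime, Instant.now()).getSeconds > refreshTime) {
  refreshedTime = Instant.now()
  batchDF.unpersist(false)
  batchDF = handler.readBatchDF(sparkSession)
    .persist(StorageLevel.MEMORY_AND_DISK)
}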
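As written, the time check in the snippet above is evaluated only once, when the driver first runs that code, so the refresh branch is never taken later. As a minimal sketch (plain Scala, in the same script/spark-shell style as the answer below), the same elapsed-time check can be wrapped in a helper that is meant to be called repeatedly, for example once per micro-batch. `RefreshTimer` and `shouldRefresh` are hypothetical names, not Spark APIs, and the current `Instant` is passed in explicitly so the logic can be exercised without waiting an hour.

```scala
import java.time.{Duration, Instant}

// Hypothetical helper (not a Spark API): remembers when the data was last refreshed
// and reports whether more than `intervalSeconds` have elapsed since then.
// It only has an effect if it is called repeatedly, e.g. once per micro-batch.
final class RefreshTimer(intervalSeconds: Long, start: Instant) {
  private var lastRefresh: Instant = start

  // Returns true, and records `now` as the new refresh time, once the interval has passed.
  def shouldRefresh(now: Instant): Boolean = {
    val elapsed = Duration.between(lastRefresh, now).getSeconds
    if (elapsed > intervalSeconds) {
      lastRefresh = now
      true
    } else {
      false
    }
  }
}

// Simulated checks at fixed offsets (in minutes) instead of waiting in real time
val t0 = Instant.parse("2021-01-01T00:00:00Z")
val timer = new RefreshTimer(intervalSeconds = 3600L, start = t0)
val results = Seq(30L, 61L, 90L, 122L).map(minutes => timer.shouldRefresh(t0.plusSeconds(minutes * 60)))
println(results.mkString(", ")) // prints "false, true, false, true"
```

With `intervalSeconds = 3600`, calling `shouldRefresh(Instant.now())` at the start of each micro-batch would give a roughly hourly cadence, checked at batch boundaries; the rate-stream approach in the answer gets the same cadence from the trigger interval instead.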

Is there a better way to achieve this in Spark Structured Streaming jobs?


Solution

  • You can do this by using the trigger-based scheduling that Structured Streaming provides.

    You can trigger the refresh (unpersist -> load -> persist) of a static DataFrame by creating an artificial "rate" stream whose micro-batches refresh the static DataFrame periodically. The idea is to:

    1. Load the static DataFrame initially and keep it as a `var`
    2. Define a method that refreshes the static DataFrame
    3. Use a "rate" stream that is triggered at the required interval (e.g. 1 hour)
    4. Read the actual streaming data and join it with the static DataFrame
    5. Give that rate stream a `foreachBatch` sink that calls the refresher method defined in step 2

    The following code runs fine with Spark 3.0.1, Scala 2.12.10 and Delta 0.7.0.

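      import java.util.Calendar
      import org.apache.spark.sql.Dataset
      import org.apache.spark.sql.streaming.Trigger
      import spark.implicits._ // provides the encoder used by .as[Long]
    
      // 1. Load the static Dataframe initially and keep it as a `var`
      // (`deltaPath` is defined in the table-creation snippet below)
      var staticDf = spark.read.format("delta").load(deltaPath)
      staticDf.persist()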
    
      //  2. Define a method that refreshes the static Dataframe
      def foreachBatchMethod[T](batchDf: Dataset[T], batchId: Long): Unit = {
        staticDf.unpersist()
        staticDf = spark.read.format("delta").load(deltaPath)
        staticDf.persist()
        println(s"${Calendar.getInstance().getTime}: Refreshing static Dataframe from DeltaLake")
      }
    
      // 3. Use a "Rate" Stream that gets triggered at the required interval (e.g. 1 hour)
      val staticRefreshStream = spark.readStream
        .format("rate")
        .option("rowsPerSecond", 1)
        .option("numPartitions", 1)
        .load()
        .selectExpr("CAST(value as LONG) as trigger")
        .as[Long]
    
      // 4. Read actual streaming data and perform join operation with static Dataframe
      // As an example I used Kafka as a streaming source
      val streamingDf = spark.readStream
        .format("kafka")
        .option("kafka.bootstrap.servers", "localhost:9092")
        .option("subscribe", "test")
        .option("startingOffsets", "earliest")
        .option("failOnDataLoss", "false")
        .load()
        .selectExpr("CAST(value AS STRING) as id", "offset as streamingField")
    
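      // joinDf is built from the value staticDf holds right now; reassigning the var later
      // does not rebuild joinDf. The refresh appears to take effect because each micro-batch
      // is re-planned and Spark substitutes the currently cached data for a matching plan.
      val joinDf = streamingDf.join(staticDf, "id")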
    
      val query = joinDf.writeStream
        .format("console")
        .option("truncate", false)
        .option("checkpointLocation", "/path/to/sparkCheckpoint")
        .start()
    
      // 5. Within that Rate Stream have a `foreachBatch` sink that calls refresher method
      staticRefreshStream.writeStream
        .outputMode("append")
        .foreachBatch(foreachBatchMethod[Long] _)
        .queryName("RefreshStream")
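        .trigger(Trigger.ProcessingTime("5 seconds")) // or e.g. "1 hour"
        .option("checkpointLocation", "/path/to/refreshCheckpoint") // separate placeholder path for this query
        .start()
    
      // Block the driver while both queries run (omit this in an interactive spark-shell)
      spark.streams.awaitAnyTermination()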
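The unpersist -> load -> persist cycle in `foreachBatchMethod` above can be factored into a small holder object. The following is a Spark-free sketch of that pattern, not code from the answer: `Refreshable` is a hypothetical name, and in the answer's setting `load` would wrap `spark.read.format("delta").load(deltaPath).persist()` while `release` would call `unpersist()` on the old DataFrame. Plain strings stand in for DataFrames so the logic can be run on its own.

```scala
import scala.collection.mutable.ArrayBuffer

// Hypothetical holder for a value that is periodically reloaded.
// `load` produces a fresh value (e.g. read + persist);
// `release` frees the previous one (e.g. unpersist).
final class Refreshable[A](load: () => A, release: A => Unit) {
  @volatile private var current: A = load()

  // Latest loaded value; @volatile makes a refresh visible to other threads.
  def get: A = current

  // Same order as the answer: release the old value, then load and keep a new one.
  def refresh(): A = synchronized {
    release(current)
    current = load()
    current
  }
}

// Usage with stand-in string values instead of DataFrames
var version = 0
val released = ArrayBuffer.empty[String]
val holder = new Refreshable[String](
  load = () => { version += 1; s"snapshot-v$version" },
  release = old => released += old
)
holder.refresh()
holder.refresh()
println(holder.get)               // prints "snapshot-v3"
println(released.mkString(", "))  // prints "snapshot-v1, snapshot-v2"
```

Releasing the old value before loading the new one mirrors the answer's order; loading first and releasing afterwards would avoid a short window in which nothing is cached, at the cost of briefly holding both copies.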
    

    To make the example complete, the Delta table was created (and later updated with new values) as follows:

      import org.apache.spark.sql.SaveMode
      import spark.implicits._
    
      val deltaPath = "file:///tmp/delta/table"
    
      val df = Seq(
        (1L, "static1"),
        (2L, "static2")
      ).toDF("id", "deltaField")
    
      df.write
        .mode(SaveMode.Overwrite)
        .format("delta")
        .save(deltaPath)