Tags: java, apache-spark, timeout, spark-streaming

Spark streaming mapWithState timeout delayed?


I expected the new mapWithState API of Spark 1.6+ to remove timed-out objects near-immediately, but there is a delay.

I'm testing the API with an adapted version of the JavaStatefulNetworkWordCount example below:

// SPACE is the whitespace pattern from the original example, i.e.
// private static final Pattern SPACE = Pattern.compile(" ");
SparkConf sparkConf = new SparkConf()
    .setAppName("JavaStatefulNetworkWordCount")
    .setMaster("local[*]");

JavaStreamingContext ssc = new JavaStreamingContext(sparkConf, Durations.seconds(1));
ssc.checkpoint("./tmp");

StateSpec<String, Integer, Integer, Tuple2<String, Integer>> mappingFunc =  
    StateSpec.function((word, one, state) -> {
        if (state.isTimingOut())
        {
             System.out.println("Timing out the word: " + word);
             return new Tuple2<String,Integer>(word, state.get());
        }
        else
        {
            int sum = one.or(0) + (state.exists() ? state.get() : 0);
            Tuple2<String, Integer> output = new Tuple2<String, Integer>(word, sum);
            state.update(sum);
            return output;
        }
});

JavaMapWithStateDStream<String, Integer, Integer, Tuple2<String, Integer>> stateDstream =
    ssc.socketTextStream(args[0], Integer.parseInt(args[1]),
     StorageLevels.MEMORY_AND_DISK_SER_2)
       .flatMap(x -> Arrays.asList(SPACE.split(x)))
       .mapToPair(w -> new Tuple2<String, Integer>(w, 1))
       .mapWithState(mappingFunc.timeout(Durations.seconds(5)));

stateDstream.stateSnapshots().print();

I feed the stream with nc (nc -l -p <port>).

When I type a word into the nc window, I see the tuple printed in the console every second. But the timing-out message is not printed 5s later, as the configured timeout suggests it should be; the time it takes for a tuple to expire seems to vary between 5 and 20 seconds.

Am I missing some configuration option, or is the timeout perhaps only performed at the same time as checkpoints?


Solution

  • Once an event times out it's NOT deleted right away, but is only marked for deletion by saving it to a 'deltaMap':

    override def remove(key: K): Unit = {
      val stateInfo = deltaMap(key)
      if (stateInfo != null) {
        stateInfo.markDeleted()
      } else {
        val newInfo = new StateInfo[S](deleted = true)
        deltaMap.update(key, newInfo)
      }
    }
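The mark-then-delete behavior can be reproduced in isolation with a minimal Java sketch (not Spark code - the class and field names here are made up for illustration): remove() leaves the entry in the map and only records a tombstone.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical, simplified model of the deltaMap behavior above:
// removing a key does not drop it, it only marks it as deleted.
public class DeltaMapSketch {
    static class StateInfo {
        Integer data;
        boolean deleted;
        StateInfo(Integer data, boolean deleted) {
            this.data = data;
            this.deleted = deleted;
        }
    }

    final Map<String, StateInfo> deltaMap = new HashMap<>();

    void put(String key, int value) {
        deltaMap.put(key, new StateInfo(value, false));
    }

    // Mirrors the Scala remove() above: mark, don't physically delete.
    void remove(String key) {
        StateInfo info = deltaMap.get(key);
        if (info != null) {
            info.deleted = true;
        } else {
            deltaMap.put(key, new StateInfo(null, true));
        }
    }

    boolean isVisible(String key) {
        StateInfo info = deltaMap.get(key);
        return info != null && !info.deleted;
    }

    public static void main(String[] args) {
        DeltaMapSketch m = new DeltaMapSketch();
        m.put("hello", 1);
        m.remove("hello");
        System.out.println(m.deltaMap.containsKey("hello")); // true: still stored
        System.out.println(m.isVisible("hello"));            // false: marked deleted
    }
}
```

Physical removal of such tombstoned entries corresponds to the compaction step discussed further below.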
    

    Then, the timed-out events are collected and sent to the output stream only at checkpoint time. That is, events that time out at batch t will appear in the output stream only at the next checkpoint - by default, after 5 batch intervals on average, i.e. around batch t+5:

    override def checkpoint(): Unit = {
      super.checkpoint()
      doFullScan = true
    }

    ...

    removeTimedoutData = doFullScan // remove timedout data only when full scan is enabled

    ...

    // Get the timed out state records, call the mapping function on each and collect the
    // data returned
    if (removeTimedoutData && timeoutThresholdTime.isDefined) {
    ...

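As a stand-alone illustration of this gating (hypothetical names, no Spark dependency), a batch only collects timed-out keys when a preceding checkpoint has set the full-scan flag:

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Sketch of the checkpoint-gated full scan: expired entries are only
// collected once a checkpoint has flipped doFullScan.
public class FullScanSketch {
    final Map<String, Long> lastUpdated = new HashMap<>();
    boolean doFullScan = false;

    void checkpoint() {
        doFullScan = true; // as in checkpoint() above
    }

    // Returns the keys whose timeout fired; empty unless a checkpoint
    // enabled the scan (removeTimedoutData = doFullScan).
    List<String> processBatch(long nowMs, long timeoutMs) {
        List<String> timedOut = new ArrayList<>();
        if (doFullScan) {
            for (Map.Entry<String, Long> e : lastUpdated.entrySet()) {
                if (nowMs - e.getValue() > timeoutMs) {
                    timedOut.add(e.getKey());
                }
            }
            doFullScan = false;
        }
        return timedOut;
    }

    public static void main(String[] args) {
        FullScanSketch s = new FullScanSketch();
        s.lastUpdated.put("word", 0L);
        System.out.println(s.processBatch(6000, 5000)); // [] - expired, but no checkpoint yet
        s.checkpoint();
        System.out.println(s.processBatch(7000, 5000)); // [word]
    }
}
```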

    Elements are actually removed only when there are enough of them and the state map is being serialized - which currently also happens only at checkpoint:

    /** Whether the delta chain length is long enough that it should be compacted */
    def shouldCompact: Boolean = {
      deltaChainLength >= deltaChainThreshold
    }

    ...

    // Write the data in the parent state map while copying the data into a new parent map for
    // compaction (if needed)
    val doCompaction = shouldCompact

    ...
    

    By default, checkpointing occurs every 10 batch intervals - in the example above, every 10 seconds. Since your timeout is 5 seconds, a timed-out event is expected to be emitted between 5 and 15 seconds after its last update.
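The 5-15 second window follows from simple arithmetic. Here is an idealized Java sketch, assuming the example's parameters (1s batches, 5s timeout, checkpoint every 10 batches) and assuming a timeout is emitted at the first checkpoint at or after expiry:

```java
// Idealized delay model (assumption: timeouts are emitted at the first
// checkpoint at or after expiry; real scheduling may shift this slightly).
public class TimeoutDelaySketch {
    static final int BATCH_SECONDS = 1;
    static final int TIMEOUT_SECONDS = 5;
    static final int CHECKPOINT_EVERY_BATCHES = 10;

    // Seconds from a key's last update until its timeout is emitted.
    static int emissionDelay(int lastUpdateSecond) {
        int expiry = lastUpdateSecond + TIMEOUT_SECONDS;
        int interval = BATCH_SECONDS * CHECKPOINT_EVERY_BATCHES;
        // round expiry up to the next checkpoint boundary
        int checkpoint = ((expiry + interval - 1) / interval) * interval;
        return checkpoint - lastUpdateSecond;
    }

    public static void main(String[] args) {
        for (int t = 0; t <= 9; t++) {
            System.out.println("last update at t=" + t + "s -> timeout emitted after "
                    + emissionDelay(t) + "s");
        }
    }
}
```

Under this model the delay stays between 5 and 14 seconds regardless of when the key was last updated, which matches the 5-15 second estimate (and roughly the 5-20 seconds observed in the question, once startup and scheduling jitter are added).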

    EDIT: Corrected and elaborated answer following comments by @YuvalItzchakov