Search code examples
scalasessionapache-flinkflink-streaming

How to use window state with Flink 1.7.1 Session Windows


I am building a Flink app with the following goals:

  1. Gather events into keyed inactivity triggered session windows
  2. Emit a replica of input events with as early as possible, augmented with session reference
  3. Emit session updates when session is opened and closed along with gathered session statistics (at session close time)

I have been able to achieve the goals above, with Tumbling windows, but I have no success do the same with Session windows.

My window processing code is as follows

package io.github.streamingwithflink.jv

import java.util.{Calendar}

import io.github.streamingwithflink.util.{MySensorSource, SensorReading, SensorTimeAssigner}
import org.apache.flink.api.common.state.ValueStateDescriptor
import org.apache.flink.api.scala._
import org.apache.flink.streaming.api.datastream.DataStreamSink
import org.apache.flink.streaming.api.scala.function.ProcessWindowFunction
import org.apache.flink.streaming.api.{TimeCharacteristic}
import org.apache.flink.streaming.api.scala.{DataStream, OutputTag, StreamExecutionEnvironment}
import org.apache.flink.streaming.api.windowing.assigners.{EventTimeSessionWindows}
import org.apache.flink.streaming.api.windowing.time.Time
import org.apache.flink.streaming.api.windowing.triggers.{Trigger, TriggerResult}
import org.apache.flink.streaming.api.windowing.windows.TimeWindow
import org.apache.flink.util.Collector

import scala.util.Random

object MySessionWindow {

  def main(args: Array[String]): Unit = {

    // set up the streaming execution environment
    val env = StreamExecutionEnvironment.getExecutionEnvironment

    // use event time for the application
    env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime)
    // configure watermark interval
    env.getConfig.setAutoWatermarkInterval(1000L)

    // ingest sensor stream
    val sensorData: DataStream[SensorReading] = env
      // SensorSource generates random temperature readings
      .addSource(new MySensorSource)
      // assign timestamps and watermarks which are required for event time
      .assignTimestampsAndWatermarks(new SensorTimeAssigner)

    val sessionizedEvents = sensorData
      .keyBy(_.id)
      // a session window with 1.5 second gap
      .window(EventTimeSessionWindows.withGap(Time.milliseconds(1500)))
      // a custom trigger that fires every event received
      .trigger(new MyTrigger)
      // count readings per window
      .process(new MySessionWindowFunction)

    sessionizedEvents.print()

    // retrieve and print session output
    val sessionOutput: DataStreamSink[String] = sessionizedEvents
      .getSideOutput(new OutputTag[String]("session-status"))
      .print()

    env.execute()
  }
}

/** A trigger that fires with every event placed in window. */
class MyTrigger
    extends Trigger[SensorReading, TimeWindow] {

  override def onElement(
      r: SensorReading,
      timestamp: Long,
      window: TimeWindow,
      ctx: Trigger.TriggerContext): TriggerResult = {

    if (timestamp >= window.getEnd) {
      TriggerResult.FIRE_AND_PURGE
    }
    else  {
      TriggerResult.FIRE
    }
  }

  override def onEventTime(
                            timestamp: Long,
                            window: TimeWindow,
                            ctx: Trigger.TriggerContext): TriggerResult = {
    // Continue. not using event time timers
    TriggerResult.CONTINUE
  }

  override def onProcessingTime(
      timestamp: Long,
      window: TimeWindow,
      ctx: Trigger.TriggerContext): TriggerResult = {
    // Continue. We don't use processing time timers
    TriggerResult.CONTINUE
  }

  override def canMerge: Boolean = {
    return true
  }

  override def onMerge(
                        window: TimeWindow,
                        ctx: Trigger.OnMergeContext) = {
  }

  override def clear(
      window: TimeWindow,
      ctx: Trigger.TriggerContext): Unit = {
    // No trigger state to clear
  }
}

/** A window function that counts the readings per sensor and window.
  * The function emits the sensor id, session reference and temperature . */
class MySessionWindowFunction
  extends ProcessWindowFunction[SensorReading, (String, Int, Double), String, TimeWindow] {

  override def process(
      key: String,
      ctx: Context,
      readings: Iterable[SensorReading],
      out: Collector[(String, Int, Double)]): Unit = {

    // count readings
    val cnt = readings.count(_ => true)
    val curTime = Calendar.getInstance.getTimeInMillis
    val lastTime = readings.last.timestamp

    val sessionRefDesc = new ValueStateDescriptor[Int]("sessionRef", classOf[Int])
    val sessionRef = ctx.windowState.getState[Int](sessionRefDesc)
    val sessionCountDesc = new ValueStateDescriptor[Int]("sessionCount", classOf[Int])
    val sessionCount = ctx.windowState.getState[Int](sessionCountDesc)
    // Side output for session
    val sessionStatus: OutputTag[String] =
      new OutputTag[String]("session-status")
    // create a new sessionRef every time new window starts
    if (cnt == 1) {
      // set sessionRef for first element
      val sessionRefValue = new Random().nextInt(998) + 1
      sessionRef.update(sessionRefValue)
      ctx.output(sessionStatus, s"Session opened: ${readings.last.id}, ref:${sessionRef.value()}")
    }
    sessionCount.update(cnt)
    out.collect((readings.last.id, sessionRef.value(), readings.last.temperature))
  }

  override def clear(
                      ctx: Context): Unit = {
    // Clearing window session context
    val sessionRefDesc = new ValueStateDescriptor[Int]("sessionRef", classOf[Int])
    val sessionRef = ctx.windowState.getState[Int](sessionRefDesc)
    val sessionCountDesc = new ValueStateDescriptor[Int]("sessionCount", classOf[Int])
    val sessionCount = ctx.windowState.getState[Int](sessionCountDesc)
    // println(s"Clearing sessionRef ${sessionRef.value()}")
    // Side output for session
    val sessionOutput: OutputTag[String] =
      new OutputTag[String]("session-status")
    ctx.output(sessionOutput, s"Session closed: ref:${sessionRef.value()}, count:${sessionCount.value()}")
    sessionRef.clear()
    sessionCount.clear()
    super.clear(ctx)
  }
}

To generate the input I am using

package io.github.streamingwithflink.util

import java.util.Calendar

import org.apache.flink.streaming.api.functions.source.RichSourceFunction
import org.apache.flink.streaming.api.functions.source.SourceFunction.SourceContext

import scala.util.Random

/**
  * Flink SourceFunction to generate SensorReadings with random temperature values.
  *
  * Each parallel instance of the source simulates 1 sensor which emit one sensor
  * reading spaced by a progressive delay capped at 3 seconds (1,2,3,1,2,3,1...)
  */
class MySensorSource extends RichSourceFunction[SensorReading] {

  // flag indicating whether source is still running.
  var running: Boolean = true

  /** run() continuously emits SensorReadings by emitting them through the SourceContext. */
  override def run(srcCtx: SourceContext[SensorReading]): Unit = {

    // initialize random number generator
    val rand = new Random()
    // look up index of this parallel task
    val taskIdx = this.getRuntimeContext.getIndexOfThisSubtask

    // initialize sensor ids and temperatures
    var curFTemp = (1 to 1).map {  // Slow
      i => ("sensor_" + (taskIdx * 10 + i), 65 + (rand.nextGaussian() * 20));
    }

//    curFTemp.foreach(t => System.out.println(t._1, t._2))

    // emit data until being canceled
    var waitTime = 0;
    while (running) {
      // Progressive 1s delay, with 3s max: 1,2,3,1,2,3,1...
      waitTime = (waitTime) % 3000 + 1000
      // update temperature
      curFTemp = curFTemp.map( t => (t._1, t._2 + rand.nextGaussian() * 0.5) )
      // get current time
      val curTime = Calendar.getInstance.getTimeInMillis

      // emit new SensorReading
      curFTemp.foreach({t => srcCtx.collect(SensorReading(t._1, curTime, t._2))})
//      curFTemp.foreach(t => println(s"TX: id:${t._1}, ts:${curTime}, temp:${t._2}"))
      Thread.sleep(waitTime)
    }
  }

  /** Cancels this SourceFunction. */
  override def cancel(): Unit = {
    running = false
  }

}

and

case class SensorReading(id: String, timestamp: Long, temperature: Double)

When I execute with Session Windows I get the following Exception

Exception in thread "main" org.apache.flink.runtime.client.JobExecutionException: Job execution failed.
    at org.apache.flink.runtime.jobmaster.JobResult.toJobExecutionResult(JobResult.java:146)
    at org.apache.flink.runtime.minicluster.MiniCluster.executeJobBlocking(MiniCluster.java:647)
    at org.apache.flink.streaming.api.environment.LocalStreamEnvironment.execute(LocalStreamEnvironment.java:123)
    at org.apache.flink.streaming.api.environment.StreamExecutionEnvironment.execute(StreamExecutionEnvironment.java:1510)
    at org.apache.flink.streaming.api.scala.StreamExecutionEnvironment.execute(StreamExecutionEnvironment.scala:645)
    at io.github.streamingwithflink.jv.MySessionWindow$.main(MySessionWindow.scala:55)
    at io.github.streamingwithflink.jv.MySessionWindow.main(MySessionWindow.scala)
Caused by: java.lang.UnsupportedOperationException: Per-window state is not allowed when using merging windows.
    at org.apache.flink.streaming.runtime.operators.windowing.WindowOperator$MergingWindowStateStore.getState(WindowOperator.java:678)
    at io.github.streamingwithflink.jv.MySessionWindowFunction.process(MySessionWindow.scala:126)
    at io.github.streamingwithflink.jv.MySessionWindowFunction.process(MySessionWindow.scala:111)
    at org.apache.flink.streaming.api.scala.function.util.ScalaProcessWindowFunctionWrapper.process(ScalaProcessWindowFunctionWrapper.scala:63)
    at org.apache.flink.streaming.runtime.operators.windowing.functions.InternalIterableProcessWindowFunction.process(InternalIterableProcessWindowFunction.java:50)
    at org.apache.flink.streaming.runtime.operators.windowing.functions.InternalIterableProcessWindowFunction.process(InternalIterableProcessWindowFunction.java:32)
    at org.apache.flink.streaming.runtime.operators.windowing.WindowOperator.emitWindowContents(WindowOperator.java:546)
    at org.apache.flink.streaming.runtime.operators.windowing.WindowOperator.processElement(WindowOperator.java:370)
    at org.apache.flink.streaming.runtime.io.StreamInputProcessor.processInput(StreamInputProcessor.java:202)
    at org.apache.flink.streaming.runtime.tasks.OneInputStreamTask.run(OneInputStreamTask.java:105)
    at org.apache.flink.streaming.runtime.tasks.StreamTask.invoke(StreamTask.java:300)
    at org.apache.flink.runtime.taskmanager.Task.run(Task.java:704)
    at java.lang.Thread.run(Thread.java:748)

It am hoping I am missing a trick, as not being able to store state with Sessions Window feels very restrictive.

Any pointers would be greatly appreciated.


Solution

  • Session windows are indeed rather special. As each new event arrives it is initially assigned to its own window, after which the set of all current session windows is processed and any possible merges are performed (based on the session gap). This approach means that there isn't really a stable notion of the session to which a given event belongs, and it renders the concept of per-window state rather awkward -- and it isn't supported.

    You might be able to build a solution using session windows based on globalState, or by using a ProcessFunction instead of the window API.