How to unit test a BroadcastProcessFunction in Flink when processElement depends on broadcast data


I implemented a Flink stream with a BroadcastProcessFunction. In processBroadcastElement I receive my model, and in processElement I apply it to each event.

I can't find a way to unit test my stream, because I can't find a way to ensure the model is dispatched before the first event. I can see two ways of achieving this:
1. Find a way to have the model pushed into the stream first
2. Have the broadcast state filled with the model prior to the execution of the stream, so that it is restored

I may have missed something, but I have not found a simple way to do this.

Here is a simple unit test that reproduces my issue:

import org.apache.flink.api.common.state.MapStateDescriptor
import org.apache.flink.streaming.api.functions.co.BroadcastProcessFunction
import org.apache.flink.streaming.api.functions.sink.SinkFunction
import org.apache.flink.streaming.api.scala._
import org.apache.flink.util.Collector
import org.scalatest.Matchers._
import org.scalatest.{BeforeAndAfter, FunSuite}

import scala.collection.mutable


class BroadCastProcessor extends BroadcastProcessFunction[Int, (Int, String), String] {

  import BroadCastProcessor._

  override def processElement(value: Int,
                              ctx: BroadcastProcessFunction[Int, (Int, String), String]#ReadOnlyContext,
                              out: Collector[String]): Unit = {
    val broadcastState = ctx.getBroadcastState(broadcastStateDescriptor)

    if (broadcastState.contains(value)) {
      out.collect(broadcastState.get(value))
    }
  }

  override def processBroadcastElement(value: (Int, String),
                                       ctx: BroadcastProcessFunction[Int, (Int, String), String]#Context,
                                       out: Collector[String]): Unit = {
    ctx.getBroadcastState(broadcastStateDescriptor).put(value._1, value._2)
  }
}

object BroadCastProcessor {
  val broadcastStateDescriptor: MapStateDescriptor[Int, String] = new MapStateDescriptor[Int, String]("int_to_string", classOf[Int], classOf[String])
}

class CollectSink extends SinkFunction[String] {

  import CollectSink._

  override def invoke(value: String): Unit = {
    values += value
  }
}

object CollectSink { // must be static
  val values: mutable.MutableList[String] = mutable.MutableList[String]()
}

class BroadCastProcessTest extends FunSuite with BeforeAndAfter {

  before {
    CollectSink.values.clear()
  }

  test("add_elem_to_broadcast_and_process_should_apply_broadcast_rule") {
    val env = StreamExecutionEnvironment.getExecutionEnvironment
    env.setParallelism(1)

    val dataToProcessStream = env.fromElements(1)

    val ruleToBroadcastStream = env.fromElements(1 -> "1", 2 -> "2", 3 -> "3")

    val broadcastStream = ruleToBroadcastStream.broadcast(BroadCastProcessor.broadcastStateDescriptor)

    dataToProcessStream
      .connect(broadcastStream)
      .process(new BroadCastProcessor)
      .addSink(new CollectSink())

    // execute
    env.execute()

    CollectSink.values should contain("1")
  }
}
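
To see why the test above depends on arrival order, here is a minimal Flink-free sketch of the same logic. Flink gives no ordering guarantee between the two inputs of a connected stream, so either interleaving below can happen at runtime. The names OrderingDemo, Event, Rule and run are hypothetical and only illustrate the idea; an immutable Map stands in for the broadcast state.

```scala
object OrderingDemo {
  // Hypothetical stand-ins for the two inputs of BroadCastProcessor
  sealed trait Input
  final case class Event(value: Int) extends Input
  final case class Rule(key: Int, label: String) extends Input

  // Feeds the inputs in the given order and returns what would reach the sink.
  // The Map plays the role of the broadcast state.
  def run(inputs: Seq[Input]): Vector[String] = {
    val start = (Map.empty[Int, String], Vector.empty[String])
    val (_, emitted) = inputs.foldLeft(start) {
      case ((rules, acc), Rule(k, v)) => (rules + (k -> v), acc)       // like processBroadcastElement
      case ((rules, acc), Event(e))   => (rules, acc ++ rules.get(e))  // like processElement
    }
    emitted
  }

  val rules: Seq[Input] = Seq(Rule(1, "1"), Rule(2, "2"), Rule(3, "3"))

  // Rules arrive before the event: the event is matched
  val rulesFirst: Vector[String] = run(rules :+ Event(1))

  // Event arrives before the rules: the event is silently dropped
  val eventFirst: Vector[String] = run(Event(1) +: rules)
}
```

With the rules first, the sink receives "1"; with the event first, it receives nothing, which is the case in which the assertion in the test fails.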

Update thanks to David Anderson
I went for the buffer solution. I defined a process function for the synchronization:

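import org.apache.flink.streaming.api.functions.co.CoProcessFunction

// Holds events back until modelNumberToWaitFor model elements have been seen.
// The buffer and counter are plain fields, so they are not checkpointed.
class SynchronizeModelAndEvent(modelNumberToWaitFor: Int) extends CoProcessFunction[Int, (Int, String), Int] {
  val eventBuffer: mutable.MutableList[Int] = mutable.MutableList[Int]()
  var modelEventsNumber = 0

  override def processElement1(value: Int, ctx: CoProcessFunction[Int, (Int, String), Int]#Context, out: Collector[Int]): Unit = {
    if (modelEventsNumber < modelNumberToWaitFor) {
      // Model not complete yet: keep the event for later
      eventBuffer += value
    } else {
      out.collect(value)
    }
  }

  override def processElement2(value: (Int, String), ctx: CoProcessFunction[Int, (Int, String), Int]#Context, out: Collector[Int]): Unit = {
    modelEventsNumber += 1

    if (modelEventsNumber >= modelNumberToWaitFor) {
      // Release the buffered events, then empty the buffer so that later
      // model elements do not emit them a second time
      eventBuffer.foreach(event => out.collect(event))
      eventBuffer.clear()
    }
  }
}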

And so I need to add it to my stream:

dataToProcessStream
  .connect(ruleToBroadcastStream)
  .process(new SynchronizeModelAndEvent(3))
  .connect(broadcastStream)
  .process(new BroadCastProcessor)
  .addSink(new CollectSink())

Thanks


Solution

  • There isn't an easy way to do this. One option is to have processElement buffer all of its input until processBroadcastElement has received the model. Another is to run the job once with no event traffic, take a savepoint once the model has been broadcast, and then restore that savepoint into the same job with its event input connected.

    By the way, the capability you are looking for is often referred to as "side inputs" in the Flink community.
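
The first suggestion (buffering inside processElement) can be sketched without any Flink types, which also makes the logic easy to unit test on its own. This is a minimal sketch under stated assumptions, not a definitive implementation: BufferingProcessor, onEvent, onRule and expectedRules are hypothetical names, onEvent stands in for processElement, onRule for processBroadcastElement, and "the model has been received" is assumed to mean that a known number of distinct rule keys has arrived.

```scala
import scala.collection.mutable

// Hypothetical, Flink-free model of a buffering BroadcastProcessFunction.
final class BufferingProcessor(expectedRules: Int) {
  // Stands in for the broadcast state
  private val rules = mutable.Map.empty[Int, String]
  // Events that arrived before the model was complete
  private val pending = mutable.ArrayBuffer.empty[Int]

  // Assumption: the model is complete once expectedRules distinct keys are known
  private def ready: Boolean = rules.size >= expectedRules

  private def applyRules(event: Int): Option[String] = rules.get(event)

  // Role of processElement: emit right away if the model is ready, else buffer
  def onEvent(event: Int): List[String] =
    if (ready) applyRules(event).toList
    else {
      pending += event
      Nil
    }

  // Role of processBroadcastElement: store the rule, then release the buffer
  // exactly once, in arrival order, when the model becomes complete
  def onRule(key: Int, label: String): List[String] = {
    rules.update(key, label)
    if (ready && pending.nonEmpty) {
      val released = pending.flatMap(applyRules(_)).toList
      pending.clear()
      released
    } else Nil
  }
}
```

In a real BroadcastProcessFunction the returned values would be passed to out.collect instead. Note that a buffer held in a plain field, as here, is not part of Flink's checkpoints, so buffered events would be lost on failure; that is acceptable in a unit test but would need proper state handling in production.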