Search code examples
Tags: scala, akka, akka-stream

Processing an Akka Stream with a One-Time Header


I have an application that accepts TCP socket connections, each of which sends data in the form:

n{json}bbbbbbbbbb...

where `n` is the length in bytes of the JSON that follows, and the JSON might be something like `{"splitEvery": 5}`, which dictates how I break up and process the potentially infinite stream of bytes that follows it.

I want to process this stream with Akka in Scala. I think streams are the right tool for this, but I am having a hard time finding an example that uses streams with distinct processing stages. Most stream flows seem to do the same thing over and over, like the `prefixAndTail` example in the Akka documentation. That example is very close to how I want to process the `n{json}` part of my stream; the difference is that I only need to do this once per connection and then move on to a different "stage" of processing.

Can anyone point me to an example of using Akka streams with distinct stages?


Solution

  • Here's a custom GraphStage that processes a stream of ByteStrings in two phases:

    • Extract chunk size from header
    • Emit ByteStrings of the specified chunk size
    import akka.stream.{Attributes, FlowShape, Inlet, Outlet}
    import akka.stream.stage.{GraphStage, GraphStageLogic, InHandler, OutHandler}
    import akka.util.ByteString
    
    /**
      * Flow stage that consumes a one-time `n{json}` header and then re-chunks the
      * remaining bytes into `ByteString`s of `splitEvery` bytes each.
      *
      * `n` is the decimal byte length of the JSON object that follows it; the JSON must
      * contain a positive integer `splitEvery` field. The last chunk may be shorter than
      * `splitEvery` when the upstream completes mid-chunk. The stage fails if the header
      * is malformed or if the upstream completes before the whole header has arrived.
      */
    class PreProcessor extends GraphStage[FlowShape[ByteString, ByteString]] {

      val in: Inlet[ByteString] = Inlet("ParseHeader.in")
      val out: Outlet[ByteString] = Outlet("ParseHeader.out")

      override val shape = FlowShape.of(in, out)

      override def createLogic(inheritedAttributes: Attributes): GraphStageLogic =
        new GraphStageLogic(shape) {

          // Bytes received but not yet emitted; holds the header bytes until it is parsed.
          private var buffer = ByteString.empty
          // Chunk size taken from the header; None until the complete header has been seen.
          private var chunkSize: Option[Int] = None
          private var upstreamFinished = false

          // Finds the splitEvery value inside the JSON header, accepting either quote
          // style and any whitespace around the colon.
          private val splitEveryPattern = """["']splitEvery["']\s*:\s*(\d+)""".r

          // Parses a decimal string into an Int, rejecting values that overflow.
          private def toIntOption(digits: String): Option[Int] =
            scala.util.Try(digits.toInt).toOption

          /**
            * @param data The buffered bytes, starting at the first byte of the header.
            * @return `Left` with an error message if the header is malformed,
            * `Right(None)` if more bytes are needed, or `Right(Some((chunkSize, headerSize)))`
            * where `headerSize` is the byte length of the length prefix plus the JSON.
            */
          def parseHeader(data: ByteString): Either[String, Option[(Int, Int)]] = {
            val prefixLength = data.indexWhere(b => b < '0' || b > '9')
            if (prefixLength < 0) Right(None) // only digits so far; the prefix may continue
            else if (prefixLength == 0) Left("header must start with the JSON length in bytes")
            else toIntOption(data.take(prefixLength).decodeString("UTF-8")) match {
              case None => Left("JSON length prefix is not a valid Int")
              // Compare without adding so a huge declared length cannot overflow.
              case Some(jsonLength) if data.size - prefixLength < jsonLength => Right(None)
              case Some(jsonLength) =>
                val headerSize = prefixLength + jsonLength
                val json = data.slice(prefixLength, headerSize).decodeString("UTF-8")
                splitEveryPattern.findFirstMatchIn(json).flatMap(m => toIntOption(m.group(1))) match {
                  // A non-positive size would make splitAt emit empty chunks forever.
                  case Some(size) if size > 0 => Right(Some((size, headerSize)))
                  case _ => Left(s"header JSON lacks a positive Int splitEvery: $json")
                }
            }
          }

          setHandler(out, new OutHandler {
            // Serve demand from the buffer first; emit() only pulls upstream when the
            // buffer cannot supply a chunk, which keeps the buffer bounded.
            override def onPull(): Unit = emit()
          })

          setHandler(in, new InHandler {
            override def onPush(): Unit = {
              buffer ++= grab(in)
              chunkSize match {
                case Some(_) => emit()
                case None =>
                  parseHeader(buffer) match {
                    case Left(error) => failStage(new IllegalArgumentException(error))
                    case Right(Some((size, headerSize))) =>
                      chunkSize = Some(size)
                      buffer = buffer.drop(headerSize)
                      emit()
                    case Right(None) => emit()
                  }
              }
            }

            override def onUpstreamFinish(): Unit = {
              upstreamFinished = true
              if (chunkSize.isEmpty && buffer.nonEmpty)
                failStage(new IllegalStateException("upstream finished before the header was complete"))
              else if (chunkSize.isEmpty || buffer.isEmpty) completeStage()
              else if (isAvailable(out)) emit()
            }
          })

          // Requests more bytes from upstream, or completes once upstream is exhausted.
          private def pullOrComplete(): Unit =
            if (isClosed(in)) completeStage()
            else pull(in)

          // Pushes one chunk if the buffer can supply it (a short final chunk is allowed
          // only after upstream finished); otherwise asks upstream for more bytes.
          // push() requires pending demand on `out`: every caller runs while `out` has
          // been pulled and not yet pushed to.
          private def emit(): Unit = chunkSize match {
            case None => pullOrComplete()
            case Some(size) =>
              val chunkReady =
                if (upstreamFinished) buffer.nonEmpty else buffer.size >= size
              if (chunkReady) {
                val (chunk, rest) = buffer.splitAt(size)
                buffer = rest
                push(out, chunk)
              } else pullOrComplete()
          }
        }
    }
    

    And a test case that illustrates its usage:

    import akka.actor.ActorSystem
    import akka.stream._
    import akka.stream.scaladsl.Source
    import akka.util.ByteString
    import org.scalatest._
    
    import scala.concurrent.Await
    import scala.concurrent.duration._
    import scala.util.Random
    
    class PreProcessorSpec extends FlatSpec with BeforeAndAfterAll {

      implicit val system: ActorSystem = ActorSystem("Test")
      implicit val materializer: ActorMaterializer = ActorMaterializer()

      val random = new Random

      private val input = """17{"splitEvery": 5}aaaaabbbbbcccccddd"""
      private val expected = Vector("aaaaa", "bbbbb", "ccccc", "ddd")

      // Shut the actor system down so the test run does not leak its threads.
      override def afterAll(): Unit =
        try Await.result(system.terminate(), 5.seconds)
        finally super.afterAll()

      /** Splits `s` into consecutive pieces of random length in [0, n); pieces may be empty. */
      def splitRandom(s: String, n: Int): List[String] = s match {
        case "" => Nil
        case _ =>
          val (head, tail) = s.splitAt(random.nextInt(n))
          head :: splitRandom(tail, n)
      }

      /** Streams `pieces` through a [[PreProcessor]] and collects the emitted chunks as strings. */
      def process(pieces: List[String]): Vector[String] = {
        val future = Source.fromIterator(() => pieces.iterator).
          map(ByteString(_)).
          via(new PreProcessor()).
          map(_.decodeString("UTF-8")).
          runFold(Vector.empty[String])(_ :+ _)
        Await.result(future, 5.seconds)
      }

      "PreProcessor" should "re-chunk a randomly fragmented payload after the header" in {
        val strings = splitRandom(input, 7)
        println(strings.map(s => s"[$s]").mkString(" ") + "\n")

        val chunks = process(strings)
        chunks.foreach(println)

        assert(chunks == expected, s"input fragments: $strings")
      }

      it should "emit every chunk when a single element carries the whole input" in {
        assert(process(List(input)) == expected)
      }
    }
    

    Example output:

    [17{"] [splitE] [very"] [] [: 5}] [aaaaa] [bbb] [bbcccc] [] [cddd]
    
    aaaaa
    bbbbb
    ccccc
    ddd