Search code examples
scala, akka-stream

How to create an akka-stream Source from a Flow that generates values recursively?


I need to traverse an API that is shaped like a tree. For example, a directory structure or threads of discussion. It can be modeled via the following flow:

// Identifier of a node in the tree-shaped API.
type ItemId = Int
// Payload carried by a node.
type Data = String

/** A fetched node: its payload plus the ids of its direct children. */
final case class Item(data: Data, kids: List[ItemId])

/** Produces a two-character random alphanumeric string used as placeholder payload. */
def randomData(): Data = {
  val twoChars = scala.util.Random.alphanumeric.take(2)
  twoChars.mkString
}

// Child ids of a node (ranges are inclusive):
//   0      => 1 to 9
//   1      => 10 to 19
//   2      => 20 to 29
//   ...
//   9      => 90 to 99
//   others => no children
// NB. I don't have access to this function, only the itemFlow.
def nested(id: ItemId): List[ItemId] =
  id match {
    case 0                     => List.range(1, 10)
    case n if n >= 1 && n <= 9 => List.range(n * 10, (n + 1) * 10)
    case _                     => Nil
  }

// Stand-in for the remote API: turns an id into an Item holding random data
// and the ids of that node's children.
val itemFlow: Flow[ItemId, Item, NotUsed] = 
  Flow.fromFunction(id => Item(randomData(), nested(id)))

How can I traverse this data? I got the following working:

import akka.NotUsed
import akka.actor.ActorSystem
import akka.stream._
import akka.stream.scaladsl._

import scala.concurrent.Await
import scala.concurrent.duration.Duration

// Implicit values carry explicit types so that implicit resolution does not depend on inference.
implicit val system: ActorSystem = ActorSystem()
implicit val materializer: ActorMaterializer = ActorMaterializer()

// Cyclic graph exposed as a flow of ids in, data out: every fetched Item sends its
// data downstream while the ids of its kids are fed back to be fetched in turn.
val loop = 
  GraphDSL.create() { implicit builder =>
    import GraphDSL.Implicits._

    val entry     = builder.add(Flow[Int])
    val merge     = builder.add(Merge[Int](2))
    val fetch     = builder.add(itemFlow)
    val bcast     = builder.add(Broadcast[Item](2))

    val data      = builder.add(Flow[Item].map(_.data))
    val kids      = builder.add(Flow[Item].mapConcat(_.kids))

    // Feedback edge: holds up to 100 pending ids; on overflow the oldest buffered id is dropped.
    val feedback  = Flow[Int].buffer(100, OverflowStrategy.dropHead)

    // Main path: incoming ids and fed-back ids are merged, fetched, then fanned out.
    entry ~> merge ~> fetch ~> bcast ~> data
    // Cycle: second broadcast output goes back into the second merge input.
    bcast ~> kids ~> feedback ~> merge

    FlowShape(entry.in, data.out)
  }

val flow = Flow.fromGraph(loop)

// Start the traversal at id 0 and print every data value that comes out.
val done = Source.single(0).via(flow).runWith(Sink.foreach(println))

// Block until the stream finishes, with no timeout.
Await.result(done, Duration.Inf)

system.terminate()

However, since I'm using a flow with a buffer, the Stream will never complete.

Completes when upstream completes and buffered elements have been drained

Flow.buffer

I read the Graph cycles, liveness, and deadlocks section multiple times and I'm still struggling to find an answer.

This would create a live lock:

import java.util.concurrent.atomic.AtomicInteger

/**
  * Builds a Source that starts from `seed`, runs every pending value through `flow`
  * and feeds the values returned by `loop` back in through an actor-backed source.
  *
  * An implicit materializer must be in scope because the feedback source is run eagerly.
  *
  * @param seed first value sent through `flow`
  * @param flow turns a pending value into an element
  * @param loop values to process next for an emitted element
  * @return the elements produced by `flow`, completing after the last pending one
  */
def unfold[S, E](seed: S, flow: Flow[S, E, NotUsed])(loop: E => List[S]): Source[E, NotUsed] = {
  // keep track of how many elements were sent to `ref` and have not yet passed the takeWhile below
  val remaining = new AtomicInteger(1) // 1 = seed

  // should be > max loop(x)
  val bufferSize = 10000

  val (ref, publisher) =
    Source.actorRef[S](bufferSize, OverflowStrategy.fail)
      .toMat(Sink.asPublisher(fanout = true))(Keep.both)
      .run()

  ref ! seed

  Source.fromPublisher(publisher)
    .via(flow)
    .map { x =>
      loop(x).foreach { c =>
        // Count the child before sending it so the counter never under-reports pending work.
        remaining.incrementAndGet()
        ref ! c
      }
      x
    }
    // The element that brings the counter to zero is the last pending one; it must still be
    // emitted, hence inclusive = true (a plain takeWhile would drop it).
    .takeWhile(_ => remaining.decrementAndGet() > 0, inclusive = true)
}

EDIT: I added a git repo to test your solution https://github.com/MasseGuillaume/source-unfold


Solution

  • I solved this problem by writing my own GraphStage.

    import akka.NotUsed
    import akka.stream._
    import akka.stream.scaladsl._
    import akka.stream.stage.{GraphStage, GraphStageLogic, OutHandler}
    
    import scala.concurrent.ExecutionContext
    
    import scala.collection.mutable
    import scala.util.{Success, Failure, Try}
    
    import scala.collection.mutable
    
    /**
      * Wraps an [[UnfoldSource]] stage as a `Source`.
      *
      * @param seeds      values the traversal starts from
      * @param flow       flow used to turn pending values into elements
      * @param loop       values to visit next for a produced element
      * @param bufferSize size limit handed to the stage
      */
    def unfoldTree[S, E](seeds: List[S], 
                         flow: Flow[S, E, NotUsed],
                         loop: E => List[S],
                         bufferSize: Int)(implicit ec: ExecutionContext): Source[E, NotUsed] = {
      val stage = new UnfoldSource[S, E](seeds, flow, loop, bufferSize)
      Source.fromGraph(stage)
    }
    
    object UnfoldSource {
      /** Adds batch dequeueing to `mutable.Queue`. */
      implicit class MutableQueueExtensions[A](private val self: mutable.Queue[A]) extends AnyVal {
        /**
          * Removes the first `n` elements from the queue and returns them in queue order.
          *
          * A non-positive `n` removes nothing and returns an empty list. If `n` exceeds the
          * queue size, `dequeue()` throws once the queue runs empty.
          */
        def dequeueN(n: Int): List[A] =
          // List.fill evaluates its by-name argument once per slot, i.e. one dequeue per element.
          List.fill(n)(self.dequeue())
      }
    }
    
    /**
      * Source stage that unfolds a tree (or forest) starting from `seeds`.
      *
      * Pending values wait in a FIFO frontier. A batch of them is run through `flow`; each
      * resulting element is buffered for emission and the values `loop` returns for it are
      * appended to the frontier. At most one batch runs at a time. The stage completes once
      * no batch is running and both the frontier and the buffer are empty.
      *
      * @param seeds      values the traversal starts from
      * @param flow       turns pending values into elements; materialized once per batch
      * @param loop       values to visit next for a produced element
      * @param bufferSize maximum batch size; must be positive. The buffer stays within this
      *                   size as long as `flow` emits at most one element per input.
      * @param ec         execution context used to register the batch-completion callback
      */
    class UnfoldSource[S, E](seeds: List[S],
                             flow: Flow[S, E, NotUsed],
                             loop: E => List[S],
                             bufferSize: Int)(implicit ec: ExecutionContext) extends GraphStage[SourceShape[E]] {
    
      // With a non-positive size the buffer always counts as full and no batch would ever start.
      require(bufferSize > 0, "bufferSize must be positive")
    
      val out: Outlet[E] = Outlet("UnfoldSource.out")
      override val shape: SourceShape[E] = SourceShape(out)
    
      override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) with OutHandler {  
        // Nodes to expand
        val frontier = mutable.Queue[S]()
        frontier ++= seeds
    
        // Nodes expanded
        val buffer = mutable.Queue[E]()
    
        // Using the flow to fetch more data
        var inFlight = false
    
        // Sink pulled but the buffer was empty
        var downstreamWaiting = false
    
        def isBufferFull(): Boolean = buffer.size >= bufferSize
    
        // Starts a batch only when none is running: the single inFlight flag cannot track
        // several concurrent batches, and completion is decided from that flag.
        def fillIfPossible(): Unit = {
          if (!inFlight && !isBufferFull() && frontier.nonEmpty) {
            fillBuffer()
          }
        }
    
        // Runs the next batch of frontier values through the flow and collects the results.
        def fillBuffer(): Unit = {
          val batchSize = Math.min(bufferSize - buffer.size, frontier.size)
          val batch = frontier.dequeueN(batchSize)
          inFlight = true
    
          val toProcess =
            Source(batch)
              .via(flow)
              .runWith(Sink.seq)(materializer)
    
          // The async callback hands the result back to the stage so its state is only
          // touched from within the stage.
          val callback = getAsyncCallback[Try[Seq[E]]]{
            case Failure(ex) =>
              fail(out, ex)
            case Success(es) =>
              inFlight = false
              es.foreach{ e =>
                buffer += e
                frontier ++= loop(e)
              }
              if (downstreamWaiting && buffer.nonEmpty) {
                downstreamWaiting = false
                push(out, buffer.dequeue())
              }
              // Keep fetching even if this batch produced nothing for a waiting downstream;
              // otherwise no further pull would arrive to restart it.
              fillIfPossible()
              checkCompletion()
          }
    
          toProcess.onComplete(callback.invoke)
        }
    
        // Completes immediately when there are no seeds.
        override def preStart(): Unit = {
          checkCompletion()
        }
    
        def checkCompletion(): Unit = {
          if (!inFlight && buffer.isEmpty && frontier.isEmpty) {
            completeStage()
          }
        } 
    
        def sendOne(e: E): Unit = {
          push(out, e)
          checkCompletion()
        }
    
        def onPull(): Unit = {
          if (buffer.nonEmpty) {
            sendOne(buffer.dequeue())
          } else {
            // Remember the demand; the batch callback pushes as soon as an element arrives.
            downstreamWaiting = true
          }
    
          fillIfPossible()
        }
    
        setHandler(out, this)
      }
    }