Chain Scala Futures when processing a Seq of objects?


import scala.concurrent.duration.Duration
import scala.concurrent.duration.Duration._
import scala.concurrent.{Await, Future}
import scala.concurrent.Future._
import scala.concurrent.ExecutionContext.Implicits.global

object TestClass {

  final case class Record(id: String)

  final case class RecordDetail(id: String)

  final case class UploadResult(result: String)

  val ids: Seq[String] = Seq("a", "b", "c", "d")

  def fetch(id: String): Future[Option[Record]] = Future {
    Thread sleep 100
    if (id != "b" && id != "d") {
      Some(Record(id))
    } else None
  }

  def fetchRecordDetail(record: Record): Future[RecordDetail] = Future {
    Thread sleep 100
    RecordDetail(record.id + "_detail")
  }

  def upload(recordDetail: RecordDetail): Future[UploadResult] = Future {
    Thread sleep 100
    UploadResult(recordDetail.id + "_uploaded")
  }

  def notifyUploaded(results: Seq[UploadResult]): Unit = println("notified " + results)

  def main(args: Array[String]): Unit = {

    //for each id from ids, call fetch method and if record exists call fetchRecordDetail 
    //and after that upload RecordDetail, collect all UploadResults into seq
    //and call notifyUploaded with that seq and await result, you should see "notified ...." in console


    // In the following line of code how do I pass result of fetch to fetchRecordDetail function
    val result = Future.traverse(ids)(x => Future(fetch(x)))
    // val result: Future[Unit] = ???

    Await.ready(result, Duration.Inf)
  }

}

My problem is that I don't know what code to put in main to make it work as described in the comments. To sum up: I have ids: Seq[String], and I want each id to go through the asynchronous methods fetch, fetchRecordDetail and upload, and finally the whole Seq of upload results to be passed to notifyUploaded.


Solution

  • I think the simplest way to do it is:

      def main(args: Array[String]): Unit = {
        import scala.util.Success // needed for the andThen callback below
    
        //for each id from ids, call fetch method and if record exists call fetchRecordDetail
        //and after that upload RecordDetail, collect all UploadResults into seq
        //and call notifyUploaded with that seq and await result, you should see "notified ...." in console
    
        // Applies the async step f only when a value is present; None is passed through unchanged
        def runWithOption[A, B](f: A => Future[B], oa: Option[A]): Future[Option[B]] = oa match {
          case Some(a) => f(a).map(b => Some(b))
          case None => Future.successful(None)
        }
    
        // One Future per id (uses TestClass.ids); ids without a record end up as None
        val resultSeq: Seq[Future[Option[UploadResult]]] = ids.map { id =>
          for {
            or  <- fetch(id)                            // Option[Record]
            ord <- runWithOption(fetchRecordDetail, or) // Option[RecordDetail]
            our <- runWithOption(upload, ord)           // Option[UploadResult]
          } yield our
        }
    
        // Seq[Future[...]] -> Future[Seq[...]], keeping only the ids that produced an upload
        val filteredResult: Future[Seq[UploadResult]] = Future.sequence(resultSeq).map(s => s.collect { case Some(ur) => ur })
        // andThen runs notifyUploaded on success; result completes only after the callback has run
        val result: Future[Seq[UploadResult]] = filteredResult.andThen { case Success(s) => notifyUploaded(s) }
    
        Await.ready(result, Duration.Inf)
      }
    

    The idea is that you first build a Seq[Future[_]] by mapping each id through all the methods (here this is done with a for-comprehension). The important trick is to make it a Seq[Future[Option[_]]]: passing the Option[_] through the whole chain via the runWithOption helper simplifies the code a lot, because an id whose fetch returned None simply stays None at every later stage, and nothing needs to block until the very last stage.
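
    To see the helper on its own, here is a minimal standalone sketch. The object name RunWithOptionDemo and the async step double are hypothetical, added only for illustration. It checks that Some(21) goes through the async step, while None short-circuits to an already completed Future(None) without invoking the step at all.

```scala
import java.util.concurrent.atomic.AtomicInteger
import scala.concurrent.{Await, Future}
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.duration._

object RunWithOptionDemo {

  // Same helper as in the answer: applies an async step only when a value is present.
  def runWithOption[A, B](f: A => Future[B], oa: Option[A]): Future[Option[B]] = oa match {
    case Some(a) => f(a).map(b => Some(b))
    case None    => Future.successful(None)
  }

  // Runs a hypothetical async step on Some(21) and on None, counting how often
  // the step is actually invoked. Returns (result for Some, result for None, call count).
  def demo(): (Option[Int], Option[Int], Int) = {
    val calls = new AtomicInteger(0)
    def double(x: Int): Future[Int] = Future {
      calls.incrementAndGet()
      x * 2
    }
    val present = Await.result(runWithOption(double, Some(21)), 5.seconds)
    val absent  = Await.result(runWithOption(double, Option.empty[Int]), 5.seconds)
    (present, absent, calls.get)
  }

  def main(args: Array[String]): Unit =
    println(demo()) // (Some(42),None,1)
}
```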

    Then you convert the Seq[Future[_]] into a Future[Seq[_]] with Future.sequence and filter out the results for ids whose fetch returned None (collect keeps only the Some values). Finally you apply notifyUploaded.
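
    Since the question started from Future.traverse: traverse(xs)(f) does the map and the sequence in a single step, so the same pipeline can also be written with plain flatMap instead of the runWithOption helper. This is a sketch, not the answer's code; the object TraverseVariant and the helpers pipeline and uploadAll are hypothetical names, and the stubs are copied from the question. Passing fetch's Future straight to flatMap (rather than wrapping it in another Future, as in Future(fetch(x))) is what lets the result of fetch reach fetchRecordDetail.

```scala
import scala.concurrent.{Await, Future}
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.duration.Duration

object TraverseVariant {

  final case class Record(id: String)
  final case class RecordDetail(id: String)
  final case class UploadResult(result: String)

  // Stubs copied from the question.
  def fetch(id: String): Future[Option[Record]] = Future {
    Thread.sleep(100)
    if (id != "b" && id != "d") Some(Record(id)) else None
  }
  def fetchRecordDetail(record: Record): Future[RecordDetail] = Future {
    Thread.sleep(100)
    RecordDetail(record.id + "_detail")
  }
  def upload(recordDetail: RecordDetail): Future[UploadResult] = Future {
    Thread.sleep(100)
    UploadResult(recordDetail.id + "_uploaded")
  }
  def notifyUploaded(results: Seq[UploadResult]): Unit = println("notified " + results)

  // Hypothetical per-id chain: a missing record short-circuits to None,
  // otherwise fetchRecordDetail and upload run one after the other.
  def pipeline(id: String): Future[Option[UploadResult]] =
    fetch(id).flatMap {
      case Some(record) => fetchRecordDetail(record).flatMap(upload).map(Some(_))
      case None         => Future.successful(None)
    }

  // traverse = map + sequence in one step; flatten drops the Nones.
  def uploadAll(ids: Seq[String]): Future[Seq[UploadResult]] =
    Future.traverse(ids)(pipeline).map(_.flatten)

  def main(args: Array[String]): Unit = {
    val ids = Seq("a", "b", "c", "d")
    // Future[Unit], the type the question's commented-out line asks for
    val result: Future[Unit] = uploadAll(ids).map(notifyUploaded)
    Await.ready(result, Duration.Inf)
  }
}
```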

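    P.S. Note that there is no error handling in this code whatsoever, and it is not clear how you expect it to behave when errors occur at different stages. As written, if any single Future fails, Future.sequence fails too, the Success-only andThen callback never fires, and Await.ready returns without notifyUploaded having been called.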
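
    If skipping a failed id is acceptable, one possible policy is to recover each id's Future individually before combining them. The sketch below assumes that policy (the question does not specify one): the object RecoverPerId, its stubbed pipeline (where "b" has no record and "c" throws) and safePipeline are all hypothetical. Because each failure is turned into None, traverse still completes successfully and notifyUploaded receives the uploads that did succeed.

```scala
import scala.concurrent.{Await, Future}
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.duration.Duration
import scala.util.control.NonFatal

object RecoverPerId {

  final case class UploadResult(result: String)

  // Hypothetical stand-in for the whole fetch -> fetchRecordDetail -> upload chain:
  // "b" has no record, "c" fails with an exception.
  def pipeline(id: String): Future[Option[UploadResult]] = id match {
    case "b" => Future.successful(None)
    case "c" => Future.failed(new RuntimeException("upload failed for " + id))
    case _   => Future(Some(UploadResult(id + "_uploaded")))
  }

  // Assumed policy: a failed id is logged and treated like a missing record,
  // so one failure does not prevent notification for the others.
  def safePipeline(id: String): Future[Option[UploadResult]] =
    pipeline(id).recover { case NonFatal(e) =>
      println(s"skipping $id: ${e.getMessage}")
      None
    }

  def uploadAll(ids: Seq[String]): Future[Seq[UploadResult]] =
    Future.traverse(ids)(safePipeline).map(_.flatten)

  def main(args: Array[String]): Unit = {
    val result = uploadAll(Seq("a", "b", "c", "d")).map(rs => println("notified " + rs))
    Await.ready(result, Duration.Inf)
  }
}
```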