Tags: scala, apache-spark, functional-programming, databricks, blockingqueue

Scala Thread Pool - Invoking APIs Concurrently



I have a use-case in Databricks where an API call has to be made for each record in a dataset of URLs. The dataset has around 100K records, and the maximum allowed concurrency is 3.
I did the implementation in Scala and ran it in a Databricks notebook. Apart from the one element left pending in the queue, I feel something is missing here.
Are the BlockingQueue and thread pool the right way to tackle this problem?

In the code below, instead of reading from the dataset, I am sampling from a Seq. Any help/thoughts will be much appreciated.

 
import java.time.LocalDateTime
import java.util.concurrent.{ArrayBlockingQueue, BlockingQueue}
import java.util.concurrent.Executors
import java.util.concurrent.TimeUnit

var inpQueue: BlockingQueue[(Int, String)] = new ArrayBlockingQueue[(Int, String)](1)

val inpDS = Seq(
  (1, "https://google.com/2X6barD"), (2, "https://google.com/3d9vCgW"),
  (3, "https://google.com/2M02Xz0"), (4, "https://google.com/2XOu2uL"),
  (5, "https://google.com/2AfBWF0"), (6, "https://google.com/36AEKsw"),
  (7, "https://google.com/3enBxz7"), (8, "https://google.com/36ABq0x"),
  (9, "https://google.com/2XBjmiF"), (10, "https://google.com/36Emlen")
)

val pool = Executors.newFixedThreadPool(3)
var i = 0
inpDS.foreach { ix =>
  inpQueue.put(ix)
  val t = new ConsumerAPIThread()
  t.setName("MyThread-" + i + " ")
  pool.execute(t)
  i = i + 1
}

println("Final Queue Size = " + inpQueue.size + "\n")

class ConsumerAPIThread() extends Thread {
  var name = ""

  override def run(): Unit = {
    val urlDetail = inpQueue.take()
    print(this.getName() + " " + Thread.currentThread().getName() + " popped " + urlDetail + " Queue Size " + inpQueue.size + " \n")
    triggerAPI((urlDetail._1, urlDetail._2))
  }

  def triggerAPI(params: (Int, String)): Unit = {
    try {
      val result = scala.io.Source.fromURL(params._2)
      println("" + result)
    } catch {
      case ex: Exception =>
        println("Exception caught")
    }
  }

  def ConsumerAPIThread(s: String): Unit = {
    name = s
  }
}

Solution

  • So, you have two requirements: the functional one is that you want to process the items in a list asynchronously; the non-functional one is that you don't want to process more than three items at once.

    Regarding the latter, the nice thing is that, as you have already shown in your question, Java natively exposes a nicely packaged Executor that runs tasks on a thread pool of fixed size, elegantly allowing you to cap the concurrency level if you work with threads.

    Moving to the functional requirement, Scala helps by having something that does precisely that as part of its standard API. In particular it uses scala.concurrent.Future, so in order to use it we'll have to reframe triggerAPI in terms of Future. The content of the function is not particularly relevant, so we'll mostly focus on its (revised) signature for now:

    import scala.concurrent.Future
    import scala.concurrent.ExecutionContext
    
    def triggerAPI(params: (Int, String))(implicit ec: ExecutionContext): Future[Unit] =
      Future {
        // some code that takes some time to run...
      }
    

    Notice that now triggerAPI returns a Future. A Future can be thought of as a read-handle to something that is going to be computed eventually. In particular, this is a Future[Unit], where Unit stands for "we don't particularly care about the output of this function, but mostly about its side effects".

    Furthermore, notice that the method now takes an implicit parameter, namely an ExecutionContext. The ExecutionContext is used to provide Futures with some form of environment where the computation happens. Scala has an API to create an ExecutionContext from a java.util.concurrent.ExecutorService, so this will come in handy to run our computation on the fixed thread pool, running no more than three callbacks at any given time.
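Concretely, wiring a fixed-size Java pool into a Scala ExecutionContext looks like the following. This is a minimal sketch; the thread name printed depends on the pool's default thread factory:

```scala
import java.util.concurrent.Executors
import scala.concurrent.{Await, ExecutionContext, ExecutionContextExecutorService, Future}
import scala.concurrent.duration._

// Wrap a fixed-size Java thread pool in a Scala ExecutionContext.
val ec: ExecutionContextExecutorService =
  ExecutionContext.fromExecutorService(Executors.newFixedThreadPool(3))

// A Future built with this context runs on one of the three pool threads.
val f = Future(Thread.currentThread().getName)(ec)
println(Await.result(f, 5.seconds)) // a pool thread name, e.g. "pool-1-thread-1"
ec.shutdown()
```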

    Before moving forward, if you have questions about Futures, ExecutionContexts and implicit parameters, the Scala documentation is your best source of knowledge.

    Now that we have the new triggerAPI method, we can use Future.traverse (here is the documentation for Scala 2.12 -- the latest version at the time of writing is 2.13 but to the best of my knowledge Spark users are stuck on 2.12 for the time being).

    The tl;dr of Future.traverse is that it takes some form of container and a function that takes the items in that container and returns a Future of something else. The function will be applied to each item in the container and the result will be a Future of the container of the results. In your case: the container is a List, the items are (Int, String) and the something else you return is a Unit.

    This means that you can simply call it like this:

    Future.traverse(inpDS)(triggerAPI)
    

    And triggerAPI will be applied to each item in inpDS.
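As a tiny standalone illustration (using the global execution context just for the demo), Future.traverse turns a List[Int] and an Int => Future[Int] into a Future[List[Int]], preserving the order of the input:

```scala
import scala.concurrent.{Await, Future}
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.duration._

// Apply an asynchronous function to every element of the list;
// the results are collected back into a list inside a single Future.
val doubled: Future[List[Int]] =
  Future.traverse(List(1, 2, 3))(i => Future(i * 2))

println(Await.result(doubled, 5.seconds)) // List(2, 4, 6)
```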

    By making sure that the execution context backed by the thread pool is in the implicit scope when calling Future.traverse, the items will be processed with the desired thread pool.

    The result of the call is Future[List[Unit]], which is not very interesting and can simply be discarded (as you are only interested in the side effects).
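If you do need to block until all the side effects have run (for example at the end of a notebook cell), you can Await the resulting Future and simply ignore the List[Unit] inside it. A minimal sketch, with a hypothetical stand-in for triggerAPI:

```scala
import scala.concurrent.{Await, Future}
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.duration._

// A stand-in for triggerAPI: side effect only, so the result type is Unit.
def sideEffect(i: Int): Future[Unit] = Future { println(s"processed $i") }

val all: Future[List[Unit]] = Future.traverse(List(1, 2, 3))(sideEffect)

// Block until every side effect has run; the List[Unit] itself is discarded.
Await.ready(all, 10.seconds)
```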

    That was a lot of talk; if you want to play around with the code I described, you can do so on Scastie.

    For reference, this is the whole implementation:

    import java.util.concurrent.{ExecutorService, Executors}
    
    import scala.concurrent.duration.DurationLong
    import scala.concurrent.Future
    import scala.concurrent.{ExecutionContext, ExecutionContextExecutorService}
    
    val datasets = List(
      (1, "https://google.com/2X6barD"),
      (2, "https://google.com/3d9vCgW"),
      (3, "https://google.com/2M02Xz0"),
      (4, "https://google.com/2XOu2uL"),
      (5, "https://google.com/2AfBWF0"),
      (6, "https://google.com/36AEKsw"),
      (7, "https://google.com/3enBxz7"),
      (8, "https://google.com/36ABq0x"),
      (9, "https://google.com/2XBjmiF")
    )
    
    val executor: ExecutorService = Executors.newFixedThreadPool(3)
    implicit val executionContext: ExecutionContextExecutorService = ExecutionContext.fromExecutorService(executor)
    
    def triggerAPI(params: (Int, String))(implicit ec: ExecutionContext): Future[Unit] =
      Future {
        val (index, _) = params
        println(s"+ started processing $index")
        val start = System.nanoTime() / 1000000
        Iterator.from(0).map(_ + 1).drop(100000000).take(1).toList.head // a noticeably slow operation
        val end = System.nanoTime() / 1000000
        val duration = (end - start).millis
        println(s"- finished processing $index after $duration")
      }
    
    Future.traverse(datasets)(triggerAPI).onComplete { _ =>
      println("* processing is over, shutting down the executor")
      executionContext.shutdown()
    }