
A better way to handle asynchronous calls when using the AWS Java SDK V2 in Scala?


Background

A few days ago, version 2.1 of the AWS Java SDK was officially released. One of its main selling points is how it handles asynchronous calls compared with the previous version of the SDK.

I decided to run some experiments using Scala and the new SDK, and I had a hard time coming up with an idiomatic way to deal with the CompletableFutures returned by the SDK.

The Question

Is there a way I can do this better, more succinctly, and with less boilerplate code?

Objective

Use the AWS SDK for Java V2 from Scala and handle successes and failures in an idiomatic way.

The Experiment

Create an async SNS client and publish 500 messages asynchronously:

Experiment 1 - Use the CompletableFuture Returned by the SDK

  import software.amazon.awssdk.services.sns.model.PublishRequest

  (0 until 500).map { i =>
    val future = client.publish(PublishRequest.builder().topicArn(arn).message(messageJava + i.toString).build())
    future.whenComplete((response, ex) => {
      val responseOption = Option(response) // response is null when the call failed
      responseOption match {
        case Some(r) => println(r.messageId())
        case None    => println(s"There was an error ${ex.getMessage}")
      }
    })
  }.foreach(future => future.join())

Here, I create a separate request for each message and publish it. The callback passed to whenComplete wraps the response in an Option because the response is null when the call fails. This is ugly because telling success from failure depends on checking the response for null.
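One way to avoid the null check is to fold the two callback arguments into an `Either` with `CompletableFuture.handle`, testing the exception rather than the response. The sketch below assumes a plain `CompletableFuture` stands in for the SDK's return value so it runs without AWS; the helper name `toEither` is hypothetical.

```scala
import java.util.concurrent.CompletableFuture

object EitherOps {
  // Hypothetical helper: fold the (result, exception) pair that
  // CompletableFuture hands to its callbacks into an Either.
  // The exception is tested instead of the result, because a
  // successful future may legitimately complete with null.
  def toEither[T](cf: CompletableFuture[T]): CompletableFuture[Either[Throwable, T]] =
    cf.handle[Either[Throwable, T]] { (result: T, ex: Throwable) =>
      if (ex != null) Left(ex) else Right(result)
    }
}
```

The resulting value can then be pattern-matched on `Left` and `Right` instead of on a possibly-null response.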

Experiment 2 - Get the result inside a Scala Future

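  import scala.concurrent.{Await, Future}
  import scala.concurrent.ExecutionContext.Implicits.global
  import scala.concurrent.duration._
  import scala.util.{Failure, Success}
  import software.amazon.awssdk.services.sns.model.{PublishRequest, PublishResponse}

  (0 until 500).map { i =>
    val jf = client.publish(PublishRequest.builder().topicArn(arn).message(messageScala + i.toString).build())
    val sf: Future[PublishResponse] = Future { jf.get }  // blocks a pool thread until the request completes
    sf.onComplete {
      case Success(response) => println(response.messageId)
      case Failure(ex)       => println(s"There was an error ${ex.getMessage}")
    }
    sf
  }.foreach(Await.result(_, 5000.millis))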

Here I call .get() on the CompletableFuture inside a Scala Future, so that from then on I only have to deal with the Scala Future.
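Two details of this approach are worth handling explicitly, sketched below with a plain `CompletableFuture` standing in for the SDK call. First, `get()` parks a thread of the execution context until the request finishes; wrapping it in `scala.concurrent.blocking` tells the global execution context that the thread is blocked, so it may add threads to compensate. Second, `get()` reports failures as a `java.util.concurrent.ExecutionException` whose cause is the original exception. The names `BlockingGet`, `viaGet`, and `outcome` are hypothetical.

```scala
import java.util.concurrent.{CompletableFuture, ExecutionException}
import scala.concurrent.{blocking, Await, ExecutionContext, Future}
import scala.concurrent.duration._
import scala.util.Try

object BlockingGet {
  // Hypothetical helper: wait for the Java future on an execution-context
  // thread, as in Experiment 2, but mark the wait as blocking.
  def viaGet[T](jf: CompletableFuture[T])(implicit ec: ExecutionContext): Future[T] =
    Future(blocking(jf.get())).recoverWith {
      // get() wraps the original failure; unwrap it so callers see
      // the same exception that completed the Java future.
      case e: ExecutionException if e.getCause != null => Future.failed(e.getCause)
    }

  // Usage: wait for the outcome with a 5-second timeout, as in the question.
  def outcome[T](jf: CompletableFuture[T]): Try[T] = {
    import ExecutionContext.Implicits.global
    Await.ready(viaGet(jf), 5.seconds).value.get
  }
}
```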

Experiment 3 - Use the scala-java8-compat library to convert the CompletableFuture to a Future

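  import scala.compat.java8.FutureConverters._  // provides toScala (scala-java8-compat)
  import scala.concurrent.Await
  import scala.concurrent.ExecutionContext.Implicits.global
  import scala.concurrent.duration._
  import scala.util.{Failure, Success}

  (0 until 500).map { i =>
    val f = client.publish(PublishRequest.builder().topicArn(arn).message(messageScala + i.toString).build()).toScala
    f.onComplete {
      case Success(response)  => // nothing to do on success
      case Failure(exception) => println(exception.getMessage)
    }
    f
  }.foreach(Await.result(_, 5000.millis))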

This is by far my favorite implementation, except that it requires a third-party, experimental library.
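If the project is on Scala 2.13 or later (an assumption; the question does not say which Scala version is used), the same conversion ships with the standard library as `scala.jdk.FutureConverters`, so no extra dependency is needed. The sketch below also replaces the per-future `Await.result` with a single `Future.sequence`. `fakePublish` is a hypothetical stand-in for `client.publish(...)` so the example runs without AWS.

```scala
import java.util.concurrent.CompletableFuture
import scala.concurrent.{Await, Future}
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.duration._
import scala.jdk.FutureConverters._

object StdlibConversion {
  // Hypothetical stand-in for client.publish(...): completes
  // asynchronously with a fake message id.
  def fakePublish(i: Int): CompletableFuture[String] =
    CompletableFuture.supplyAsync[String](() => s"message-$i")

  // Convert every Java future with the standard-library asScala,
  // then combine them so there is a single Await at the end.
  def publishAll(n: Int): Seq[String] = {
    val all = Future.sequence((0 until n).map(i => fakePublish(i).asScala))
    Await.result(all, 5.seconds)
  }
}
```

If any single publish fails, the combined future fails, so per-message handling (like the `onComplete` callbacks in the experiments) still belongs on the individual futures, for example by recovering each one before sequencing.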

Observations

  • In general, all these implementations performed roughly the same, with the future.join() version being slightly faster than the others.
  • Initializing the client and publishing 500 messages took around 2 seconds with each approach.
  • A sequential version of this code takes just under a minute (55 seconds).
  • You can see the complete code here
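The difference between roughly 2 seconds and 55 seconds is consistent with the concurrent versions having many requests in flight at once, while a sequential version waits for each response before sending the next. One way to express a sequential version with futures (not necessarily how the sequential version above was written) is to chain each call onto the previous one with `flatMap`. `fakePublish` is a hypothetical stand-in for a converted `client.publish(...)`, instrumented to record how many calls overlap.

```scala
import java.util.concurrent.atomic.AtomicInteger
import scala.concurrent.{Await, Future}
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.duration._

object Sequential {
  private val inFlight = new AtomicInteger(0)
  // Highest number of fake requests observed running at the same time.
  val maxInFlight = new AtomicInteger(0)

  // Hypothetical stand-in for a (converted) client.publish(...) call.
  def fakePublish(i: Int): Future[String] = Future {
    val now = inFlight.incrementAndGet()
    maxInFlight.accumulateAndGet(now, (a: Int, b: Int) => math.max(a, b))
    inFlight.decrementAndGet()
    s"message-$i"
  }

  // Start request i only after request i - 1 has completed, by chaining
  // each call onto the previous future with flatMap.
  def publishInOrder(n: Int): Future[Vector[String]] =
    (0 until n).foldLeft(Future.successful(Vector.empty[String])) { (acc, i) =>
      acc.flatMap(ids => fakePublish(i).map(ids :+ _))
    }

  def runAll(n: Int): Vector[String] = Await.result(publishInOrder(n), 30.seconds)
}
```

Because each call starts only inside the previous future's callback, at most one request is ever in flight, so the total time is roughly the sum of the individual latencies.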

Solution

  • You mentioned that you are happy converting the CompletableFuture to a scala.concurrent.Future, but you do not want to take a dependency on scala-java8-compat.

    In that case you can simply roll your own; you only need the Java 8 to Scala direction:

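    import java.util.concurrent.CompletableFuture
    import scala.concurrent.{Await, Future, Promise}
    import scala.concurrent.ExecutionContext.Implicits.global
    import scala.concurrent.duration._
    import scala.util.{Failure, Success}

    object CompletableFutureOps {

      implicit class CompletableFutureToScala[T](cf: CompletableFuture[T]) {
        def asScala: Future[T] = {
          val p = Promise[T]()
          // Complete the Scala promise once the Java future finishes.
          // Check the exception, not the result: a successful future
          // may legitimately complete with null.
          cf.whenCompleteAsync { (result: T, ex: Throwable) =>
            if (ex != null) p failure ex
            else            p success result
          }
          p.future
        }
      }
    }

    def showByExample: Unit = {
      import CompletableFutureOps._
      (0 until 500).map { i =>
        val f = CompletableFuture.supplyAsync(() => i).asScala
        f.onComplete {
          case Success(response)  => println("Success: " + response)
          case Failure(exception) => println(exception.getMessage)
        }
        f
      }.foreach(Await.result(_, 5000.millis))
    }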
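One caveat with a hand-rolled converter: when the future being converted is a dependent stage (for example, the result of `thenApply` on a future that failed), `whenComplete` receives the failure wrapped in a `java.util.concurrent.CompletionException`. A variant that unwraps it might look like the sketch below; the names `CompletableFutureUnwrapOps`, `asScalaUnwrapped`, and `UnwrapDemo` are hypothetical.

```scala
import java.util.concurrent.{CompletableFuture, CompletionException}
import scala.concurrent.{Await, Future, Promise}
import scala.concurrent.duration._

object CompletableFutureUnwrapOps {
  implicit class UnwrappingConverter[T](cf: CompletableFuture[T]) {
    def asScalaUnwrapped: Future[T] = {
      val p = Promise[T]()
      cf.whenComplete { (result: T, ex: Throwable) =>
        ex match {
          case null => p.success(result)
          // Dependent stages report the original failure wrapped in a
          // CompletionException; surface the underlying cause instead.
          case ce: CompletionException if ce.getCause != null => p.failure(ce.getCause)
          case other => p.failure(other)
        }
      }
      p.future
    }
  }
}

object UnwrapDemo {
  import CompletableFutureUnwrapOps._

  // A successful dependent stage converts to its value.
  def successValue(): Int =
    Await.result(
      CompletableFuture.completedFuture(41).thenApply[Int]((x: Int) => x + 1).asScalaUnwrapped,
      5.seconds)

  // A dependent stage of a failed future: without unwrapping, the Scala
  // future would fail with CompletionException rather than the cause.
  def failureOfDependentStage(): Throwable = {
    val source = new CompletableFuture[Int]()
    val dependent = source.thenApply[Int]((x: Int) => x + 1)
    source.completeExceptionally(new IllegalArgumentException("boom"))
    Await.ready(dependent.asScalaUnwrapped, 5.seconds).value.get.failed.get
  }
}
```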