Tags: multithreading, scala, concurrency, scala-cats, monix

Stop all async Tasks when failures exceed a threshold?


I'm using Monix Task for async control.

scenario

  1. tasks are executed in parallel
  2. if failures occur more than X times
  3. stop all tasks that have not yet completed (as quickly as possible)

my solution

I came up with the idea of racing 1. the result against 2. an error counter, and cancelling the loser.
With Task.race, if the error counter reaches the threshold first, then the tasks are cancelled by Task.race.

experiment

on Ammonite REPL

{
  import $ivy.`io.monix::monix:3.1.0`
  import monix.eval.Task
  import monix.execution.atomic.Atomic
  import scala.concurrent.duration._
  import monix.execution.Scheduler
  //import monix.execution.Scheduler.Implicits.global
  implicit val s = Scheduler.fixedPool("race", 2) // pool size

  val taskSize = 100
  val errCounter = Atomic(0)
  val threshold = 3

  // every task just sleeps 100ms and then bumps the error counter,
  // i.e. in this experiment every completed task counts as a failure
  val tasks = (1 to taskSize).map(_ => Task.sleep(100.millis).map(_ => errCounter.increment()))
  // the guard polls the counter until it reaches the threshold
  val guard = Task(f"stop because too many error: ${errCounter.get()}")
    .restartUntil(_ => errCounter.get() >= threshold)

  // race the guard against all tasks running in parallel; the loser is cancelled
  val race = Task
    .race(guard, Task.gather(tasks))
    .runToFuture
    .onComplete { case x => println(x); println(f"completed task: ${errCounter.get()}") }
}

issue

The outcome depends on the thread pool size!?

For pool size 1
the outcome is almost always task success, i.e. no stop.

Success(Right(.........))
completed task: 100 // all tasks succeeded!

For pool size 2
the outcome is very non-deterministic between success and failure, and the cancellation is not accurate. For example:

Success(Left(stop because too many error: 1))
completed task: 98

The cancellation happens as late as after 98 tasks have completed.
The error count is strangely smaller than the threshold.

The default global scheduler shows the same behavior.

For pool size 200
the outcome is more deterministic and the stopping happens earlier, thus more accurate in the sense that fewer tasks were completed.

Success(Left(stop because too many error: 2))
completed task: 8

The larger the pool size, the better.


If I change Task.gather to Task.sequence, all the issues disappear!


What is the cause of this dependency on pool size? How can I improve it, or is there a better alternative for stopping tasks once too many errors occur?


Solution

  • What you're seeing is likely an effect of the Monix scheduler and how it aims for fairness. It's a fairly complex topic, but the documentation and Scaladocs are excellent (see: https://monix.io/docs/3x/execution/scheduler.html#execution-model)

    When you have only one thread (or few) it takes a while until the "guard" Task gets another turn to check. With Task.gather you start 100 tasks at once, so the scheduler is very busy and the "guard" cannot check again until the other tasks are already done. If you have one thread per task the scheduler cannot guarantee fairness and therefore the "guard" unfairly checks much more frequently and can finish sooner.

    If you use Task.sequence those 100 tasks are executed sequentially, which is why the "guard" task gets many more opportunities to finish as soon as needed. If you want to keep your code the way it is, you could use Task.gatherN(parallelism = 4), which limits the parallelism and therefore allows your "guard" to check more often (a middle ground between Task.sequence and Task.gather).
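
    For example, an untested sketch of that change against your experiment above (guard, tasks, errCounter and the scheduler stay exactly as you defined them):

    // bound the parallelism so the scheduler regularly gets a chance
    // to run the "guard" in between batches of tasks
    val race = Task
      .race(guard, Task.gatherN(parallelism = 4)(tasks))
      .runToFuture
      .onComplete { case x => println(x); println(f"completed task: ${errCounter.get()}") }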

    It seems a bit like Go code to me (using Task.race like Go's select), and you're also using unconstrained side effects, which further complicates understanding what's going on. I've tried to rewrite your program in a more idiomatic way, and for complicated concurrency I usually reach for streams like Observable:

    import cats.effect.concurrent.Ref
    import monix.eval.Task
    import monix.execution.Scheduler
    import monix.reactive.Observable
    
    import scala.concurrent.duration._
    
    object ErrorThresholdDemo extends App {
    
      //import monix.execution.Scheduler.Implicits.global
      implicit val s: Scheduler = Scheduler.fixedPool("race", 2) // pool size
    
      val taskSize  = 100
      val threshold = 30
    
      val program = for {
        errCounter <- Ref[Task].of(0)
    
        tasks = (1 to taskSize).map(n => Task.sleep(100.millis).flatMap(_ => errCounter.update(_ + (n % 2))))
    
        tasksFinishedCount <- Observable
          .fromIterable(tasks)
          .mapParallelUnordered(parallelism = 4) { task =>
            task
          }
          .takeUntilEval(errCounter.get.restartUntil(_ >= threshold))
          .map(_ => 1)
          .sumL
    
        errorCount <- errCounter.get
        _          <- Task(println(f"completed tasks: $tasksFinishedCount, errors: $errorCount"))
      } yield ()
    
      program.runSyncUnsafe()
    }
    

    As you can see, I no longer use global mutable side effects but instead Ref, which internally also uses Atomic but provides a functional API that we can use with Task. For demonstration purposes I also changed the threshold to 30, and only every other task will "error". So the expected output is always around completed tasks: 60, errors: 30, no matter the thread-pool size.

    I'm still using polling with errCounter.get.restartUntil(_ >= threshold), which might burn a bit too much CPU for my taste, but it's close to your original idea and works well.
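
    If the polling bothers you, one option (an untested sketch; the 10.millis interval is just an arbitrary value I picked) is to space the polls out with delayExecution, so each restart waits a bit before reading the counter again:

    // drop-in replacement for the takeUntilEval line above:
    // poll the counter every 10 millis instead of spinning
    .takeUntilEval(errCounter.get.delayExecution(10.millis).restartUntil(_ >= threshold))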

    Usually I don't create a list of tasks up front but instead throw the inputs into the Observable and create the tasks inside .mapParallelUnordered. This code keeps your list, which is why there is no real mapping involved (it already contains tasks).
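
    A sketch of that style, as a drop-in replacement for the tasksFinishedCount step above (the tasks list disappears and each Task is built from its input n inside the operator):

        tasksFinishedCount <- Observable
          .fromIterable(1 to taskSize)                  // feed the raw inputs in
          .mapParallelUnordered(parallelism = 4) { n => // build the Task per input
            Task.sleep(100.millis).flatMap(_ => errCounter.update(_ + (n % 2)))
          }
          .takeUntilEval(errCounter.get.restartUntil(_ >= threshold))
          .map(_ => 1)
          .sumL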

    You can choose your desired parallelism much like with Task.gatherN which is pretty nice imo.

    Let me know if anything is still unclear :)