Tags: multithreading, scala, concurrency, functional-programming, java.util.concurrent

Custom synchronization using java.util.concurrent with cats.effect


I have a requirement for fairly custom, non-trivial synchronization that can be implemented with a fair ReentrantLock and a Phaser. It does not seem possible to implement it with fs2 and cats.effect without non-trivial customization.

Since every blocking operation has to be wrapped in a Blocker, here is the code:

private val l: ReentrantLock = new ReentrantLock(true)
private val c: Condition = l.newCondition
private val blocker: Blocker = //...

//F is declared on the class level
def lockedMutex(conditionPredicate: Int => Boolean): F[Unit] = blocker.blockOn {
  Sync[F].delay(l.lock()).bracket(_ => Sync[F].delay{
    while(!conditionPredicate(2)){
      c.await()
    }
  })(_ => Sync[F].delay(l.unlock()))
}

QUESTION: Is it guaranteed that the code containing c.await() will be executed on the same Thread that acquires and releases the ReentrantLock?

This is crucial: if it is not, an IllegalMonitorStateException will be thrown.
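
To make the failure mode concrete, here is a minimal sketch that uses only the JDK and no cats.effect. The names LockOwnershipDemo and throwsOnOtherThread are made up for illustration. It assumes nothing beyond the documented behavior of ReentrantLock: only the owning thread may unlock it or wait on one of its conditions.

```scala
import java.util.concurrent.atomic.AtomicBoolean
import java.util.concurrent.locks.ReentrantLock

object LockOwnershipDemo {
  // Hypothetical helper: runs `body` on a fresh thread and reports
  // whether it threw IllegalMonitorStateException.
  def throwsOnOtherThread(body: => Unit): Boolean = {
    val thrown = new AtomicBoolean(false)
    val t = new Thread(() => {
      try body
      catch { case _: IllegalMonitorStateException => thrown.set(true) }
    })
    t.start()
    t.join()
    thrown.get()
  }

  def main(args: Array[String]): Unit = {
    val lock = new ReentrantLock(true) // fair, as in the question
    val condition = lock.newCondition()

    lock.lock() // the main thread now owns the lock
    try {
      // Both calls run on a thread that does not own the lock.
      println(throwsOnOtherThread(lock.unlock()))     // prints "true"
      println(throwsOnOtherThread(condition.await())) // prints "true"
    } finally lock.unlock() // released by the owning thread, so this succeeds
  }
}
```

So the concern applies to c.await() as much as to l.unlock(): each must run on the thread that called l.lock().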


Solution

  • You really do not need to worry about threads when using something like cats-effect; instead, you can describe your problem at a higher level.

    This should give you the behavior you want: it keeps running high-priority jobs until none are left, and only then picks low-priority jobs. After finishing a low-priority job, each fiber first checks whether there are more high-priority jobs before picking another low-priority one:

    import cats.effect.Async
    import cats.effect.std.Queue
    import cats.effect.syntax.all._
    import cats.syntax.all._
    
    import scala.concurrent.ExecutionContext
    
    object HighLowPriorityRunner {
      final case class Config[F[_]](
          highPriorityJobs: Queue[F, F[Unit]],
          lowPriorityJobs: Queue[F, F[Unit]],
          customEC: Option[ExecutionContext]
      )
    
      def apply[F[_]](config: Config[F])
                     (implicit F: Async[F]): F[Unit] = {
        val processOneJob =
          config.highPriorityJobs.tryTake.flatMap {
            case Some(hpJob) => hpJob
            case None => config.lowPriorityJobs.tryTake.flatMap {
              case Some(lpJob) => lpJob
              case None => F.unit
            }
          }
    
        val loop: F[Unit] = processOneJob.start.foreverM
    
        config.customEC.fold(ifEmpty = loop)(ec => loop.evalOn(ec))
      }
    }
    

    You can use the customEC to provide your own ExecutionContext to control the number of real threads that are running your fibers under the hood.

    The code can be used like this:

    import cats.effect.{Async, IO, IOApp, Resource}
    import cats.effect.std.Queue
    import cats.effect.syntax.all._
    import cats.syntax.all._
    
    import java.util.concurrent.Executors
    import scala.concurrent.ExecutionContext
    import scala.concurrent.duration._
    
    object Main extends IOApp.Simple {
      override final val run: IO[Unit] =
        Resource.make(IO(Executors.newFixedThreadPool(2)))(ec => IO.blocking(ec.shutdown())).use { ec =>
          Program[IO](ExecutionContext.fromExecutor(ec))
        }
    }
    
    object Program {
      private def createJob[F[_]](id: Int)(implicit F: Async[F]): F[Unit] =
        F.delay(println(s"Starting job ${id} on thread ${Thread.currentThread.getName}")) *>
        F.delay(Thread.sleep(1.second.toMillis)) *> // Blocks the Fiber! Only for testing; use F.sleep in real code.
        F.delay(println(s"Finished job ${id}!"))
    
      def apply[F[_]](customEC: ExecutionContext)(implicit F: Async[F]): F[Unit] = for {
        highPriorityJobs <- Queue.unbounded[F, F[Unit]]
        lowPriorityJobs <- Queue.unbounded[F, F[Unit]]
        runnerFiber <- HighLowPriorityRunner(HighLowPriorityRunner.Config(
          highPriorityJobs,
          lowPriorityJobs,
          Some(customEC)
        )).start
        _ <- List.range(0, 10).traverse_(id => highPriorityJobs.offer(createJob(id)))
        _ <- List.range(10, 15).traverse_(id => lowPriorityJobs.offer(createJob(id)))
        _ <- F.sleep(5.seconds)
        _ <- List.range(15, 20).traverse_(id => highPriorityJobs.offer(createJob(id)))
        _ <- runnerFiber.join.void
      } yield ()
    }
    

    This should produce output like the following:

    Starting job 0 on thread pool-1-thread-1
    Starting job 1 on thread pool-1-thread-2
    Finished job 0!
    Finished job 1!
    Starting job 2 on thread pool-1-thread-1
    Starting job 3 on thread pool-1-thread-2
    Finished job 2!
    Finished job 3!
    Starting job 4 on thread pool-1-thread-1
    Starting job 5 on thread pool-1-thread-2
    Finished job 4!
    Finished job 5!
    Starting job 6 on thread pool-1-thread-1
    Starting job 7 on thread pool-1-thread-2
    Finished job 6!
    Finished job 7!
    Starting job 8 on thread pool-1-thread-1
    Starting job 9 on thread pool-1-thread-2
    Finished job 8!
    Finished job 9!
    Starting job 10 on thread pool-1-thread-1
    Starting job 11 on thread pool-1-thread-2
    Finished job 10!
    Finished job 11!
    Starting job 15 on thread pool-1-thread-1
    Starting job 16 on thread pool-1-thread-2
    Finished job 15!
    Finished job 16!
    Starting job 17 on thread pool-1-thread-1
    Starting job 18 on thread pool-1-thread-2
    Finished job 17!
    Finished job 18!
    Starting job 19 on thread pool-1-thread-1
    Starting job 12 on thread pool-1-thread-2
    Finished job 19!
    Starting job 13 on thread pool-1-thread-1
    Finished job 12!
    Starting job 14 on thread pool-1-thread-2
    Finished job 13!
    Finished job 14!
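
    The output above only ever shows two thread names because customEC wraps a pool of two threads. As a rough illustration of that capping effect in isolation, here is a sketch that uses plain scala.concurrent.Future instead of cats-effect fibers (a deliberate simplification). FixedPoolDemo and threadsUsed are hypothetical names introduced for this example only.

    ```scala
    import java.util.concurrent.Executors
    import scala.concurrent.duration._
    import scala.concurrent.{Await, ExecutionContext, Future}

    object FixedPoolDemo {
      // Runs `tasks` short jobs on a pool of `poolSize` threads and returns
      // the names of the threads that actually executed them.
      def threadsUsed(poolSize: Int, tasks: Int): Set[String] = {
        val pool = Executors.newFixedThreadPool(poolSize)
        implicit val ec: ExecutionContext = ExecutionContext.fromExecutor(pool)
        try {
          val jobs = List.range(0, tasks).map { _ =>
            Future {
              Thread.sleep(10) // stand-in for real work
              Thread.currentThread().getName
            }
          }
          Await.result(Future.sequence(jobs), 10.seconds).toSet
        } finally pool.shutdown()
      }

      def main(args: Array[String]): Unit = {
        val names = threadsUsed(poolSize = 2, tasks = 8)
        // Eight jobs were submitted, but no more than two threads ran them.
        println(names.size <= 2) // prints "true"
      }
    }
    ```

    Raising or lowering poolSize changes how many jobs can run at the same time, which is the knob customEC exposes in the runner.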
    

    Thanks to Gavin Bisesi (@Daenyth) for refining my original idea into this!


    Full code available here.