Tags: scala, concurrency, scala-cats, cats-effect

Race condition caused by MVar


Consider the following simple example:

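import cats.effect.{ContextShift, IO, Timer}
import cats.effect.concurrent.MVar
import cats.implicits._
import scala.collection.mutable
import scala.concurrent.ExecutionContext
import scala.concurrent.duration._

implicit val timer: Timer[IO]     = IO.timer(ExecutionContext.global)
implicit val cs: ContextShift[IO] = IO.contextShift(ExecutionContext.global)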

val mvarF = MVar.of[IO, mutable.Map[Int, Int]](mutable.Map.empty)

mvarF.flatMap(mvar =>

  mvar.take.bracket(st => {
    IO(st.put(1, 1)) >> (IO.sleep(2.seconds) >> IO(st.clear())).start
  })(mvar.put)
    >>
  mvar.take.bracket(st =>
    IO(println(s"Size before sleep ${st.size}")) >> IO.sleep(2.seconds) >> IO(println(s"Size after sleep ${st.size}"))
  )(mvar.put)

).unsafeRunSync()

It prints:

Size before sleep 1
Size after sleep 0

In this example, the job started on a separate fiber modifies the object held in the MVar while that object has been taken (acquired) by another job.

This is extremely unsafe. Is there a way to prohibit such usage?


Solution

  • As pointed out by Jasper, your main problem (for this specific code example) is that the bracket is released right after your use starts a new fiber: IO(st.put(1, 1)) >> (IO.sleep(2.seconds) >> IO(st.clear())).start. So your use is actually a mutable.Map[Int, Int] => IO[Fiber[IO, Unit]], and it completes as soon as the fiber has been started, not when the fiber finishes. The map is then put back and taken by the second job while the forked fiber can still clear it.

    You just have to remove that start to get the intended behaviour: your use becomes a mutable.Map[Int, Int] => IO[Unit], and the bracket will not put the map back until the use IO has completed. This means the map will already be empty for both print actions.

    mvarF.flatMap(mvar =>
      mvar.take.bracket(st => {
        IO(st.put(1, 1)) >>
          IO.sleep(2.seconds) >>
            IO(st.clear())
      })(mvar.put)
        >>
        mvar.take.bracket(st =>
          IO(println(s"Size before sleep ${st.size}")) >>
            IO.sleep(2.seconds) >>
              IO(println(s"Size after sleep ${st.size}"))
        )(mvar.put)
    
    ).unsafeRunSync()
    
    Size before sleep 0
    Size after sleep 0
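    To see the difference in isolation, here is a minimal sketch, assuming cats-effect 2.x as in the question (BracketReleaseOrder, order and record are illustrative names, not library API). It logs when the bracket releases relative to a slow step that is either run inline or forked with start:

```scala
import cats.effect.{ContextShift, IO, Timer}
import cats.effect.concurrent.Ref
import cats.implicits._
import scala.concurrent.ExecutionContext
import scala.concurrent.duration._

object BracketReleaseOrder {
  implicit val cs: ContextShift[IO] = IO.contextShift(ExecutionContext.global)
  implicit val timer: Timer[IO]     = IO.timer(ExecutionContext.global)

  // Runs a bracket whose `use` either forks a slow step with .start or runs it
  // inline, and returns the order in which "released" and "slow step done" were logged.
  def order(fork: Boolean): IO[List[String]] =
    Ref.of[IO, List[String]](Nil).flatMap { log =>
      def record(event: String): IO[Unit] = log.update(_ :+ event)

      val slowStep: IO[Unit] = IO.sleep(200.millis) >> record("slow step done")

      // With .start, `use` completes as soon as the fiber is forked,
      // so the release runs while the slow step is still pending.
      val use: IO[Unit] = if (fork) slowStep.start.void else slowStep

      IO.unit.bracket(_ => use)(_ => record("released")) >>
        IO.sleep(400.millis) >> // leave time for a forked fiber to finish
        log.get
    }

  def main(args: Array[String]): Unit = {
    println(order(fork = true).unsafeRunSync())  // List(released, slow step done)
    println(order(fork = false).unsafeRunSync()) // List(slow step done, released)
  }
}
```

    With fork = true the release is logged first, which is how the second take in the question gets the map while the forked clear() is still pending.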
    

    But this is actually just a coincidence of this specific code example: the IOs are chained with flatMap (>>), which tells the runtime to perform them sequentially.

    MVar gives you control over re-assignment of the variable (take empties it, put fills it again), but you are not doing any re-assignment at all: the same mutable map is taken, mutated in place and put back. Hence this code is not really using any capabilities of MVar; it is just sitting there as a spectator.

    So using MVar in this way will have ZERO impact on the thread safety of your code.
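    One way to make MVar actually do that job, sketched here under the assumption of cats-effect 2.x, is to store an immutable Map in it, so that every change has to be a re-assignment through take and put. ImmutableMVar, program and modify are hypothetical names. The background fiber has the same shape as the one in the question, but it can no longer change the value under the current holder; it blocks on take until the holder puts the value back:

```scala
import cats.effect.{ContextShift, IO, Timer}
import cats.effect.concurrent.MVar
import cats.implicits._
import scala.concurrent.ExecutionContext
import scala.concurrent.duration._

object ImmutableMVar {
  implicit val cs: ContextShift[IO] = IO.contextShift(ExecutionContext.global)
  implicit val timer: Timer[IO]     = IO.timer(ExecutionContext.global)

  // Returns (size before sleep, size after sleep, final size) as seen by a job
  // that holds the value while a background fiber tries to clear it.
  def program: IO[(Int, Int, Int)] =
    MVar.of[IO, Map[Int, Int]](Map.empty).flatMap { mvar =>
      // With an immutable Map the only way to change the state is take + put,
      // so a change has to wait until the current holder puts the value back.
      def modify(f: Map[Int, Int] => Map[Int, Int]): IO[Unit] =
        mvar.take.flatMap(st => mvar.put(f(st)))

      for {
        _     <- modify(_.updated(1, 1))
        // Same shape as the forked job in the question: clear the map after 100 ms.
        clear <- (IO.sleep(100.millis) >> modify(_ => Map.empty)).start
        // Hold the value for 300 ms; the clearing fiber blocks on take meanwhile.
        sizes <- mvar.take.bracket { st =>
                   val before = st.size
                   IO.sleep(300.millis).map(_ => (before, st.size))
                 }(mvar.put)
        _     <- clear.join
        state <- mvar.read
      } yield (sizes._1, sizes._2, state.size)
    }

  def main(args: Array[String]): Unit =
    println(program.unsafeRunSync()) // (1,1,0)
}
```

    The holder sees a size of 1 before and after its sleep, and the clear only takes effect once the value has been put back.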

    mvarF.flatMap(mvar =>
      mvar.take.bracket(st =>
        IO(println(s"Size before first sleep - ${st.size}")) >> IO.sleep(2.seconds) >> IO(println(s"Size after first sleep - ${st.size}"))
      )(mvar.put)
    ).unsafeRunAsyncAndForget()
    
    mvarF.flatMap(mvar =>
      mvar.take.bracket(st => {
        IO(st.put(1, 1)) >> IO.sleep(2.seconds) >> IO(st.clear())
      })(mvar.put)
    ).unsafeRunAsyncAndForget()
    
    mvarF.flatMap(mvar =>
      mvar.take.bracket(st =>
        IO(println(s"Size before second sleep - ${st.size}")) >> IO.sleep(2.seconds) >> IO(println(s"Size after second sleep - ${st.size}"))
      )(mvar.put)
    ).unsafeRunAsyncAndForget()
    
    Size before first sleep - 0
    Size before second sleep - 1
    Size after first sleep - 0
    Size after second sleep - 0
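    Note that each mvarF.flatMap(...) call above runs mvarF again, and MVar.of allocates a new MVar every time it is run, so the three jobs take from three different MVars that all wrap the same map instance (the mutable.Map.empty argument is evaluated only once, when mvarF is defined). A sketch that allocates one MVar and shares it between concurrently started fibers, assuming cats-effect 2.x (SharedMVar, observe and mutate are illustrative names):

```scala
import cats.effect.{ContextShift, IO, Timer}
import cats.effect.concurrent.MVar
import cats.implicits._
import scala.collection.mutable
import scala.concurrent.ExecutionContext
import scala.concurrent.duration._

object SharedMVar {
  implicit val cs: ContextShift[IO] = IO.contextShift(ExecutionContext.global)
  implicit val timer: Timer[IO]     = IO.timer(ExecutionContext.global)

  // Allocates ONE MVar and runs the three jobs concurrently against it.
  // Returns the (before, after) sizes seen by the two observing jobs.
  def program: IO[List[(Int, Int)]] =
    MVar.of[IO, mutable.Map[Int, Int]](mutable.Map.empty).flatMap { mvar =>
      val observe: IO[(Int, Int)] =
        mvar.take.bracket { st =>
          for {
            before <- IO(st.size)
            _      <- IO.sleep(200.millis)
            after  <- IO(st.size)
          } yield (before, after)
        }(mvar.put)

      // Adds an entry and removes it again before releasing the MVar.
      val mutate: IO[Unit] =
        mvar.take.bracket(st =>
          IO(st.put(1, 1)) >> IO.sleep(200.millis) >> IO(st.clear())
        )(mvar.put)

      for {
        f1 <- observe.start
        f2 <- mutate.start
        f3 <- observe.start
        r1 <- f1.join
        _  <- f2.join
        r3 <- f3.join
      } yield List(r1, r3)
    }

  def main(args: Array[String]): Unit =
    println(program.unsafeRunSync()) // List((0,0), (0,0))
}
```

    Because every job goes through the same take/put pair, the mutator's temporary entry is never visible to an observer, whatever order the fibers run in. This only holds while every job plays by the rules, though: a reference to the map that escapes the bracket, like the forked fiber in the question, can still mutate it at any time.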
    

    You can use a Semaphore to get a race-free scope.

    import cats.effect.concurrent.Semaphore
    
    class IOWithSemaphore[A](
        private val a: A,
        private val semaphore: Semaphore[IO]
      ) {
    
      // withPermit acquires the permit, runs use(a) and always releases the
      // permit afterwards, even if use(a) fails or is cancelled
      def unitUse(use: A => IO[Unit]): IO[Unit] =
        semaphore.withPermit(use(a))
    
    }
    
    val map = mutable.Map.empty[Int, Int]
    
    Semaphore[IO](1).map(semaphore => {
      val mapIOWithSemaphore = new IOWithSemaphore[mutable.Map[Int, Int]](map, semaphore)
    
      // using unsafeRunAsyncAndForget to emulate parallel usage
    
      mapIOWithSemaphore.unitUse(map =>
        IO(println(s"Size before first sleep - ${map.size}")) >> IO.sleep(2.seconds) >> IO(println(s"Size after first sleep - ${map.size}"))
      ).unsafeRunAsyncAndForget()
    
      mapIOWithSemaphore.unitUse(map =>
        IO(println(s"MUTATION BEGIN")) >> IO(map.put(1, 1)) >> IO.sleep(2.seconds) >> IO(map.clear()) >> IO(println(s"MUTATION END"))
      ).unsafeRunAsyncAndForget()
    
      mapIOWithSemaphore.unitUse(map =>
        IO(println(s"Size before second sleep - ${map.size}")) >> IO.sleep(2.seconds) >> IO(println(s"Size after second sleep - ${map.size}"))
      ).unsafeRunAsyncAndForget()
    
    }).unsafeRunAsyncAndForget()
    
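    // Block the main thread forever so the asynchronously started jobs can run
    // (uses scala.concurrent.{Await, Promise} and scala.concurrent.duration.Duration)
    Await.result(Promise[Unit]().future, Duration.Inf)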
    
    Size before first sleep - 0
    Size after first sleep - 0
    MUTATION BEGIN
    MUTATION END
    Size before second sleep - 0
    Size after second sleep - 0
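    A note on permit handling: writing acquire, use and release as three sequential steps is not exception safe. If use fails, release is never reached and the permit is lost, so every later caller would block forever. Semaphore#withPermit releases the permit even on failure or cancellation. A small sketch comparing the two, assuming cats-effect 2.x (PermitLeak and availableAfterFailure are illustrative names):

```scala
import cats.effect.{ContextShift, IO}
import cats.effect.concurrent.Semaphore
import cats.implicits._
import scala.concurrent.ExecutionContext

object PermitLeak {
  implicit val cs: ContextShift[IO] = IO.contextShift(ExecutionContext.global)

  // Runs a failing `use` under a one-permit semaphore and reports how many
  // permits are available afterwards.
  def availableAfterFailure(useWithPermit: Boolean): IO[Long] =
    Semaphore[IO](1).flatMap { sem =>
      val failingUse: IO[Unit] = IO.raiseError(new RuntimeException("boom"))
      val guarded: IO[Unit] =
        if (useWithPermit) sem.withPermit(failingUse)
        else sem.acquire >> failingUse >> sem.release // release is skipped on failure
      guarded.attempt >> sem.available
    }

  def main(args: Array[String]): Unit = {
    println(availableAfterFailure(useWithPermit = false).unsafeRunSync()) // 0
    println(availableAfterFailure(useWithPermit = true).unsafeRunSync())  // 1
  }
}
```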