Tags: scala, scala-cats, cats-effect, fs2

fs2 - Sharing a Ref with 2 streams


I'm trying to share a Ref[F, A] between 2 concurrent streams. Below is a simplified example of the actual scenario.

  class Container[F[_]](implicit F: Sync[F]) {
    private val counter = Ref[F].of(0)

    def incrementBy2 = counter.flatMap(c => c.update(i => i + 2))

    def printCounter = counter.flatMap(c => c.get.flatMap(i => F.delay(println(i))))
  }

In the main function,

object MyApp extends IOApp {

  def run(args: List[String]): IO[ExitCode] = {
    val s = for {
      container <- Ref[IO].of(new Container[IO]())
    } yield {
      val incrementBy2 = Stream.repeatEval(
          container.get
            .flatTap(c => c.incrementBy2)
            .flatMap(c => container.update(_ => c))
        )
        .metered(2.second)
        .interruptScope

      val printStream = Stream.repeatEval(
          container.get
            .flatMap(_.printCounter)
        )
        .metered(1.seconds)

      incrementBy2.concurrently(printStream)
    }
    Stream.eval(s)
      .flatten
      .compile
      .drain
      .as(ExitCode.Success)
  }
}

The updates made by incrementBy2 are not visible in printStream. How can I fix this? I would appreciate any help understanding the mistake in this code.

Thanks


Solution

  • Your code is broken starting from the class definition: you are not even updating the same Ref.

    Remember that the point of IO is to be a description of a computation. Ref[F].of(0) returns a program that, when evaluated, creates a new Ref, so every flatMap on counter creates a separate Ref when it runs. incrementBy2 and printCounter therefore never touch the same Ref.
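    To make the difference concrete, here is a minimal sketch assuming cats-effect 3. The names RefSharingSketch, makeRef, notShared, and shared are made up for illustration and do not appear in the question.

    ```scala
    import cats.effect.{IO, Ref}

    object RefSharingSketch {
      // A description of "allocate a fresh Ref holding 0"; nothing is allocated yet.
      val makeRef: IO[Ref[IO, Int]] = Ref[IO].of(0)

      // Each flatMap on makeRef runs the allocation again, so the update and
      // the read below operate on two different Refs.
      val notShared: IO[Int] =
        makeRef.flatMap(ref => ref.update(_ + 2)) *> makeRef.flatMap(ref => ref.get)

      // Allocate once, then use the same Ref for both operations.
      val shared: IO[Int] =
        makeRef.flatMap(ref => ref.update(_ + 2) *> ref.get)
    }
    ```

    When run, notShared yields 0 while shared yields 2, which mirrors what happens in Container: its counter field plays the role of makeRef.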

    Also, your code is not doing tagless final the right way (and some may argue that even the right way is not worth it: https://alexn.org/blog/2022/04/18/scala-oop-design-sample/).
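    For comparison, one common shape for a tagless final version is sketched below. This is only an illustration under the assumption of cats-effect 3, not the only way to do it: the interface is parameterized by F, and the constructor returns F[Counter[F]] so the Ref is allocated once and captured by both methods. TaglessSketch is a hypothetical wrapper name.

    ```scala
    import cats.effect.{Ref, Sync}

    object TaglessSketch {
      // The algebra only talks about F; it does not expose the Ref.
      trait Counter[F[_]] {
        def incrementBy2: F[Unit]
        def printCounter: F[Unit]
      }

      object Counter {
        // Allocating the Ref is an effect, so the constructor is effectful too.
        def inMemory[F[_]](implicit F: Sync[F]): F[Counter[F]] =
          F.map(Ref[F].of(0)) { ref =>
            new Counter[F] {
              // Both members close over the same, already allocated Ref.
              val incrementBy2: F[Unit] = ref.update(_ + 2)
              val printCounter: F[Unit] =
                F.flatMap(ref.get)(i => F.delay(println(i)))
            }
          }
      }
    }
    ```

    A caller would evaluate Counter.inMemory[IO] once and pass the resulting Counter[IO] to both streams.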

    This is how I would write your code:

    import cats.effect.{IO, IOApp}
    import fs2.Stream
    import scala.concurrent.duration._

    trait Counter {
      def incrementBy2: IO[Unit]
      def printCounter: IO[Unit]
    }
    object Counter {
      val inMemory: IO[Counter] =
        IO.ref(0).map { ref =>
          new Counter {
            override final val incrementBy2: IO[Unit] =
              ref.update(c => c + 2)
            
            override final val printCounter: IO[Unit] =
              ref.get.flatMap(IO.println)
          }
        }
    }
    
    object Program {
      def run(counter: Counter): Stream[IO, Unit] =
        Stream
          .repeatEval(counter.printCounter)
          .metered(1.second)
          .concurrently(
            Stream.repeatEval(counter.incrementBy2).metered(2.seconds)
          ).interruptAfter(10.seconds)
    }
    
    object Main extends IOApp.Simple {
      override final val run: IO[Unit] =
        Stream
          .eval(Counter.inMemory)
          .flatMap(Program.run)
          .compile
          .drain
    }
    

    PS: I would actually not have printCounter at all. I would expose a getCounter instead and do the printing in the Program.
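    A rough sketch of that variant, assuming cats-effect 3 and fs2 3. The wrapper GetCounterSketch and the exact signature getCounter: IO[Int] are assumptions made for illustration.

    ```scala
    import cats.effect.IO
    import fs2.Stream
    import scala.concurrent.duration._

    object GetCounterSketch {
      trait Counter {
        def incrementBy2: IO[Unit]
        def getCounter: IO[Int]
      }

      object Counter {
        // The Ref is allocated once, when inMemory is evaluated.
        val inMemory: IO[Counter] =
          IO.ref(0).map { ref =>
            new Counter {
              override val incrementBy2: IO[Unit] = ref.update(_ + 2)
              override val getCounter: IO[Int] = ref.get
            }
          }
      }

      object Program {
        // The program, not the counter, decides what to do with the value.
        def run(counter: Counter): Stream[IO, Unit] =
          Stream
            .repeatEval(counter.getCounter)
            .metered(1.second)
            .evalMap(n => IO.println(n))
            .concurrently(
              Stream.repeatEval(counter.incrementBy2).metered(2.seconds)
            )
            .interruptAfter(10.seconds)
      }
    }
    ```

    Keeping the counter free of printing also makes it easy to check in isolation, for example by incrementing twice and reading the value back.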

