Tags: scala, scala-cats, cats-effect

Confused about cats-effect Async.memoize


I'm fairly new to cats-effect, but I think I am getting a handle on it. However, I have run into a situation where I want to memoize the result of an IO, and it's not doing what I expect.

The function I want to memoize transforms String => String, but the transformation requires a network call, so it is implemented as a function String => IO[String]. In a non-IO world, I'd simply save the result of the call, but the defining function doesn't actually have access to it, as it doesn't execute until later. And if I save the constructed IO[String], it won't actually help, as that IO would repeat the network call every time it's used. So instead, I try to use Async.memoize, which has the following documentation:

Lazily memoizes f. For every time the returned F[F[A]] is bound, the effect f will be performed at most once (when the inner F[A] is bound the first time).

What I expect from memoize is a function that only ever executes once for a given input, AND where the contents of the returned IO are only ever evaluated once; in other words, I expect the resulting IO to act as if it were IO.pure(result), except the first time. But that's not what seems to be happening. Instead, I find that while the called function itself only executes once, the contents of the IO are still evaluated every time - exactly as would occur if I tried to naively save and reuse the IO.

I constructed an example to show the problem:

import cats.effect.{Async, IO}

def plus1(num: Int): IO[Int] = {
  println("foo")
  IO(println("bar")) *> IO(num + 1)
}
var fooMap = Map[Int, IO[IO[Int]]]()
def mplus1(num: Int): IO[Int] = {
  val check = fooMap.get(num)
  val res = check.getOrElse {
    val plus = Async.memoize(plus1(num))
    fooMap = fooMap + ((num, plus))
    plus
  }
  res.flatten
}

println("start")
val call1 = mplus1(2)
val call2 = mplus1(2)
val result = (call1 *> call2).unsafeRunSync()
println(result)
println(fooMap.toString)
println("finish")

The output of this program is:

start
foo
bar
bar
3
Map(2 -> <function1>)
finish

Although the plus1 function itself only executes once (one "foo" printed), the output "bar" contained within the IO is printed twice, when I expect it to also print only once. (I have also tried flattening the IO returned by Async.memoize before storing it in the map, but that doesn't do much).


Solution

  • Consider the following examples.

    Given the following helper methods:

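    import cats.effect.{Async, IO}

    // Outer IO builds the inner IO; the inner IO prints and computes.
    def plus1(num: Int): IO[IO[Int]] = {
      IO(IO(println("plus1")) *> IO(num + 1))
    }
    
    // Binding the returned outer IO yields a memoized inner IO.
    def mPlus1(num: Int): IO[IO[Int]] = {
      Async.memoize(plus1(num).flatten)
    }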
    

    Let's build a program that evaluates plus1(1) twice.

    val program1 = for {
      io <- plus1(1)
      _ <- io
      _ <- io
    } yield {}
    program1.unsafeRunSync()
    

    This produces the expected output of printing plus1 twice.

    If you do the same using the mPlus1 method instead:

    val program2 = for {
      io <- mPlus1(1)
      _ <- io
      _ <- io
    } yield {}
    program2.unsafeRunSync()
    

    It will print plus1 just once, confirming that memoization is working.

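    The trick with memoization is that the outer IO returned by memoize must itself be bound only once to have the desired effect: every time the outer IO is bound, it produces a new, independently memoized inner IO. Consider the following program, which highlights this.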

    val memIo = mPlus1(1)
    val program3 = for {
      io1 <- memIo
      io2 <- memIo
      _ <- io1
      _ <- io2
    } yield {}
    program3.unsafeRunSync()
    

    This prints plus1 twice, because io1 and io2 come from two separate binds of memIo and are therefore memoized separately.

    As for your example, foo is printed once because plus1(num) is only called when the key is missing from the map, and that happens only once. bar is printed every time the IO is evaluated because the map stores the outer IO[IO[Int]]: each call to res.flatten binds that outer IO again, which produces a fresh memoized inner IO, so the memoization effect is lost.
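
One possible way to get the behaviour the question asks for is to bind the outer IO once and cache the inner, already-memoized IO[Int] per key. The following is a minimal sketch under stated assumptions, not a definitive implementation: it assumes cats-effect 2.x (where Ref lives in cats.effect.concurrent), and the names MemoCacheSketch, cachedPlus1 and calls are hypothetical, introduced only for illustration. It also keeps the cache in a Ref rather than a var, which is a design choice the original answer does not make.

```scala
import java.util.concurrent.atomic.AtomicInteger

import cats.effect.{Async, IO}
import cats.effect.concurrent.Ref

// Hypothetical sketch, assuming cats-effect 2.x.
object MemoCacheSketch {
  // Counts how often the simulated network call actually runs.
  val calls = new AtomicInteger(0)

  // Stand-in for the network call: all side effects live inside the IO.
  def plus1(num: Int): IO[Int] =
    IO {
      calls.incrementAndGet()
      println("bar")
      num + 1
    }

  // The cache maps each input to the *inner* memoized IO[Int],
  // not to the outer IO[IO[Int]] returned by Async.memoize.
  def cachedPlus1(cache: Ref[IO, Map[Int, IO[Int]]])(num: Int): IO[Int] =
    for {
      // Binding the outer IO here only allocates a memoized handle;
      // plus1's effect does not run until the handle itself is bound.
      fresh <- Async.memoize(plus1(num))
      // Atomically keep an existing handle, or install the fresh one.
      chosen <- cache.modify { m =>
        m.get(num) match {
          case Some(existing) => (m, existing)
          case None           => (m + (num -> fresh), fresh)
        }
      }
      result <- chosen
    } yield result

  // Calls the cached function twice with the same input.
  val program: IO[(Int, Int)] =
    for {
      cache <- Ref.of[IO, Map[Int, IO[Int]]](Map.empty)
      a     <- cachedPlus1(cache)(2)
      b     <- cachedPlus1(cache)(2)
    } yield (a, b)
}
```

Running MemoCacheSketch.program.unsafeRunSync() should print bar a single time and return (3, 3): the second call finds the memoized handle already in the map and reuses its stored result. An unused fresh handle is simply discarded, which costs nothing because memoization is lazy.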