Tags: scala, functional-programming, tail-recursion, scala-cats, fs2

How to reason about stack safety in Scala Cats / fs2?


Here is a piece of code from the fs2 documentation. The function go is recursive. How do we know whether it is stack safe, and, more generally, how can we reason about whether any function is stack safe?

import fs2._

def tk[F[_],O](n: Long): Pipe[F,O,O] = {
  def go(s: Stream[F,O], n: Long): Pull[F,O,Unit] = {
    s.pull.uncons.flatMap {
      case Some((hd,tl)) =>
        hd.size match {
          case m if m <= n => Pull.output(hd) >> go(tl, n - m)
          case m => Pull.output(hd.take(n.toInt)) >> Pull.done
        }
      case None => Pull.done
    }
  }
  in => go(in,n).stream
}
// tk: [F[_], O](n: Long)fs2.Pipe[F,O,O]

Stream(1,2,3,4).through(tk(2)).toList
// res33: List[Int] = List(1, 2)

Would it also be stack safe if we call go from another method?

def tk[F[_],O](n: Long): Pipe[F,O,O] = {
  def go(s: Stream[F,O], n: Long): Pull[F,O,Unit] = {
    s.pull.uncons.flatMap {
      case Some((hd,tl)) =>
        hd.size match {
          case m if m <= n => otherMethod(...)
          case m => Pull.output(hd.take(n.toInt)) >> Pull.done
        }
      case None => Pull.done
    }
  }

  def otherMethod(...) = {
    Pull.output(hd) >> go(tl, n - m)
  }

  in => go(in,n).stream
}
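For concreteness, here is a compilable version of the second snippet. The parameter list of otherMethod is an assumption that fills in the elided "...": it simply passes the values the helper's body uses (hd, tl, n, m). The wrapper name TakeViaHelper is also hypothetical. Everything else follows the fs2 calls already used in the question.

```scala
import fs2._

object TakeViaHelper {
  def tk[F[_], O](n: Long): Pipe[F, O, O] = {
    def go(s: Stream[F, O], n: Long): Pull[F, O, Unit] =
      s.pull.uncons.flatMap {
        case Some((hd, tl)) =>
          hd.size match {
            case m if m <= n => otherMethod(hd, tl, n, m)
            case _           => Pull.output(hd.take(n.toInt)) >> Pull.done
          }
        case None => Pull.done
      }

    // Hypothetical signature for the question's otherMethod(...).
    // go(tl, n - m) only builds the next step (an uncons followed by a
    // flatMap); it does not run it, so routing the call through a helper
    // adds no recursion depth by itself.
    def otherMethod(hd: Chunk[O], tl: Stream[F, O], n: Long, m: Int): Pull[F, O, Unit] =
      Pull.output(hd) >> go(tl, n - m)

    in => go(in, n).stream
  }
}
```

Used the same way as the original, Stream(1, 2, 3, 4).through(TakeViaHelper.tk[Pure, Int](2)).toList should again give List(1, 2).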

Solution

  • My previous answer here gives some background information that might be useful. The basic idea is that some effect types have flatMap implementations that support stack-safe recursion directly: you can nest flatMap calls, either explicitly or through recursion, as deeply as you want without overflowing the stack.

    For some effect types it's not possible for flatMap to be stack-safe, because of the semantics of the effect. In other cases it may be possible to write a stack-safe flatMap, but the implementers might have decided not to because of performance or other considerations.
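To make the distinction concrete, here is a small sketch contrasting cats.Eval, whose flatMap is trampolined, with Option, whose flatMap simply calls the function it is given. The object and function names are made up for illustration, and it assumes cats-core is on the classpath.

```scala
import cats.Eval

object FlatMapDepth {
  // Eval.flatMap only records the continuation; .value later runs the
  // whole chain in a loop, so this recursion does not grow the JVM stack.
  def evalCount(n: Int): Eval[Int] =
    Eval.now(n).flatMap { k =>
      if (k == 0) Eval.now(0) else evalCount(k - 1).map(_ + 1)
    }

  // Option.flatMap applies its function immediately, so each level of this
  // recursion is a real nested call; a large n can throw StackOverflowError.
  def optionCount(n: Int): Option[Int] =
    Some(n).flatMap { k =>
      if (k == 0) Some(0) else optionCount(k - 1).map(_ + 1)
    }
}
```

FlatMapDepth.evalCount(1000000).value should return 1000000, while optionCount at a similar depth will typically overflow the stack with default JVM settings. The recursive code looks identical; only the flatMap implementation differs.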

    Unfortunately there's no standard (or even conventional) way to know whether the flatMap for a given type is stack-safe. Cats does include a tailRecM operation that should provide stack-safe monadic recursion for any lawful monadic effect type, and sometimes looking at a tailRecM implementation that's known to be lawful can provide some hints about whether a flatMap is stack-safe. In the case of Pull it looks like this:

    def tailRecM[A, B](a: A)(f: A => Pull[F, O, Either[A, B]]): Pull[F, O, B] =
      f(a).flatMap {
        case Left(a)  => tailRecM(a)(f)
        case Right(b) => Pull.pure(b)
      }
    

    This tailRecM is just recursing through flatMap, and we know that Pull's Monad instance is lawful, which is pretty good evidence that Pull's flatMap is stack-safe. The one complicating factor here is that the instance for Pull has an ApplicativeError constraint on F that Pull's flatMap doesn't, but in this case that doesn't change anything.
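The same approach works for any type with a lawful Monad instance. As a sketch (the names GenericLoop and sumTo are hypothetical), here is a loop written against Monad[F].tailRecM instead of explicit recursion: Left carries the next loop state and Right ends the loop. It assumes cats-core, with the Option and Eval instances brought in through cats.implicits._.

```scala
import cats.{Eval, Monad}
import cats.implicits._

object GenericLoop {
  // Sums 1..n for any F with a Monad instance. The looping happens inside
  // tailRecM, so stack safety is the instance's responsibility, not ours.
  def sumTo[F[_]: Monad](n: Long): F[Long] =
    Monad[F].tailRecM[(Long, Long), Long]((n, 0L)) { case (k, acc) =>
      if (k <= 0) Monad[F].pure(Right(acc))      // done: return the sum
      else Monad[F].pure(Left((k - 1, acc + k))) // keep looping
    }
}
```

Both GenericLoop.sumTo[Option](1000000L) and GenericLoop.sumTo[Eval](1000000L).value should produce 500000500000 without overflowing, on the assumption (the same one the answer makes for Pull) that these instances are lawful and stack-safe in tailRecM.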

    So the tk implementation here is stack-safe because flatMap on Pull is stack-safe, and we know that from looking at its tailRecM implementation. (If we dug a little deeper we could figure out that flatMap is stack-safe because Pull is essentially a wrapper for FreeC, which is trampolined.)
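To illustrate why trampolining makes flatMap stack-safe, here is a deliberately tiny trampoline. It is a hypothetical toy (Tramp, Done, Suspend and FlatMap are made-up names), not fs2's FreeC, but it relies on the same idea: flatMap only builds a data structure, and a single tail-recursive loop interprets it, keeping pending continuations in a heap-allocated list instead of on the call stack.

```scala
import scala.annotation.tailrec

object TrampolineSketch {
  sealed trait Tramp[A] {
    // flatMap and map never run anything; they just add a node.
    def flatMap[B](f: A => Tramp[B]): Tramp[B] = FlatMap(this, f)
    def map[B](f: A => B): Tramp[B] = flatMap(a => Done(f(a)))

    // The interpreter: one loop, with pending continuations kept in a List.
    final def run: A = {
      @tailrec
      def loop(t: Tramp[Any], conts: List[Any => Tramp[Any]]): Any = t match {
        case Done(a) =>
          conts match {
            case Nil       => a
            case k :: rest => loop(k(a), rest)
          }
        case Suspend(thunk) => loop(thunk(), conts)
        case FlatMap(sub, f) =>
          loop(sub.asInstanceOf[Tramp[Any]], f.asInstanceOf[Any => Tramp[Any]] :: conts)
      }
      loop(this.asInstanceOf[Tramp[Any]], Nil).asInstanceOf[A]
    }
  }
  final case class Done[A](a: A) extends Tramp[A]
  final case class Suspend[A](thunk: () => Tramp[A]) extends Tramp[A]
  final case class FlatMap[X, A](sub: Tramp[X], f: X => Tramp[A]) extends Tramp[A]

  // Recursion through flatMap, shaped like go: the recursive call is
  // suspended, so constructing sum(n) does not itself recurse.
  def sum(n: Long): Tramp[Long] =
    if (n <= 0) Done(0L)
    else Suspend(() => sum(n - 1)).map(_ + n)
}
```

TrampolineSketch.sum(1000000L).run returns 500000500000. Note that without the Suspend, sum(n - 1) would be evaluated while the value is being constructed, which is ordinary recursion again; in go the recursive call is similarly deferred because it sits inside a flatMap continuation.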

    It probably wouldn't be terribly hard to rewrite tk in terms of tailRecM, although we'd have to add the otherwise unnecessary ApplicativeError constraint. I'm guessing the authors of the documentation chose not to do that for clarity, and because they knew Pull's flatMap is fine.


    Update: here's a fairly mechanical tailRecM translation:

    import cats.ApplicativeError
    import fs2._
    
    def tk[F[_], O](n: Long)(implicit F: ApplicativeError[F, Throwable]): Pipe[F, O, O] =
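      // Loop state: (remaining input, number of elements still to take).
      in => Pull.syncInstance[F, O].tailRecM((in, n)) {
        case (s, n) => s.pull.uncons.flatMap {
          case Some((hd, tl)) =>
            hd.size match {
              // The whole chunk fits: emit it, then continue (Left) with the tail.
              case m if m <= n => Pull.output(hd).as(Left((tl, n - m)))
              // The chunk is bigger than what's left: emit a prefix and stop (Right).
              case _ => Pull.output(hd.take(n.toInt)).as(Right(()))
            }
          // Input exhausted: stop.
          case None => Pull.pure(Right(()))
        }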
      }.stream
    

    Note that there's no explicit recursion.


    The answer to your second question depends on what the other method looks like, but in the case of your specific example, >> will just result in more flatMap layers, so it should be fine.

    To address your question more generally, this whole topic is a confusing mess in Scala. You shouldn't have to dig into implementations like we did above just to know whether a type supports stack-safe monadic recursion or not. Better conventions around documentation would help here, but unfortunately we're not doing a very good job of that. You could always use tailRecM to be "safe" (which is what you'll want to do when the F[_] is generic, anyway), but even then you're trusting that the Monad implementation is lawful.

    To sum up: it's a bad situation all around, and in sensitive situations you should definitely write your own tests to verify that implementations like this are stack-safe.
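As one possible shape for such a test, here is a sketch that drives the question's tk with many one-element chunks, so that go recurses once per element. The object name, the helper takeAll and the sizes are arbitrary choices; it assumes fs2's Stream.range, chunkLimit and Stream.chunk, and uses a plain assert rather than any particular test framework.

```scala
import fs2._

object TkStackSafetyCheck {
  // tk as in the question's first snippet.
  def tk[F[_], O](n: Long): Pipe[F, O, O] = {
    def go(s: Stream[F, O], n: Long): Pull[F, O, Unit] =
      s.pull.uncons.flatMap {
        case Some((hd, tl)) =>
          hd.size match {
            case m if m <= n => Pull.output(hd) >> go(tl, n - m)
            case _           => Pull.output(hd.take(n.toInt)) >> Pull.done
          }
        case None => Pull.done
      }
    in => go(in, n).stream
  }

  // One element per chunk, so the number of go steps equals size.
  def takeAll(size: Int): List[Int] =
    Stream
      .range(0, size)
      .chunkLimit(1)
      .flatMap(c => Stream.chunk(c))
      .through(tk[Pure, Int](size.toLong))
      .toList

  def main(args: Array[String]): Unit = {
    val size = 200000
    // If recursion through Pull's flatMap used the call stack, a depth like
    // this would be expected to throw StackOverflowError.
    assert(takeAll(size).length == size)
  }
}
```

The important part is choosing an input that forces the recursion to be deep; a test on Stream(1, 2, 3, 4) alone would pass whether or not the implementation is stack-safe.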