
Cats: Non-tail-recursive tailRecM method for Monads


In Cats, when a Monad instance is created by extending the Monad trait, an implementation of the tailRecM method must be provided.

Below is a scenario for which I found it impossible to provide a tail-recursive implementation of tailRecM:

  import cats.Monad

  sealed trait Tree[+A]

  final case class Branch[A](left: Tree[A], right: Tree[A]) extends Tree[A]

  final case class Leaf[A](value: A) extends Tree[A]

  implicit val treeMonad: Monad[Tree] = new Monad[Tree] {

    override def pure[A](value: A): Tree[A] = Leaf(value)

    override def flatMap[A, B](initial: Tree[A])(func: A => Tree[B]): Tree[B] =
      initial match {
        case Branch(l, r) => Branch(flatMap(l)(func), flatMap(r)(func))
        case Leaf(value) => func(value)
      }

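    // @tailrec cannot be enabled here: the recursive calls in the Branch case
    // run inside the flatMap lambdas, so they are not in tail position
    //@tailrec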
    override def tailRecM[A, B](a: A)(func: (A) => Tree[Either[A, B]]): Tree[B] = {
      func(a) match {
        case Branch(l, r) =>
          Branch(
            flatMap(l) {
              case Right(l) => pure(l)
              case Left(l) => tailRecM(l)(func)
            },
            flatMap(r){
              case Right(r) => pure(r)
              case Left(r) => tailRecM(r)(func)
            }
          )

        case Leaf(Left(value)) => tailRecM(value)(func)

        case Leaf(Right(value)) => Leaf(value)
      }
    }
  }

1) Given the example above, how can this tailRecM method be used to optimize flatMap calls? Is the implementation of flatMap overridden or modified by tailRecM at compile time?

2) If tailRecM is not tail recursive, as above, will it still be more efficient than using the original flatMap method?

Please share your thoughts.


Solution

  • Relation between tailRecM and flatMap

    To answer your first question, the following code is part of FlatMapLaws.scala from cats-laws. It tests the consistency between the flatMap and tailRecM methods.

    /**
     * It is possible to implement flatMap from tailRecM and map
     * and it should agree with the flatMap implementation.
     */
    def flatMapFromTailRecMConsistency[A, B](fa: F[A], fn: A => F[B]): IsEq[F[B]] = {
      val tailRecMFlatMap = F.tailRecM[Option[A], B](Option.empty[A]) {
        case None => F.map(fa) { a => Left(Some(a)) }
        case Some(a) => F.map(fn(a)) { b => Right(b) }
      }
    
      F.flatMap(fa)(fn) <-> tailRecMFlatMap
    }
    

    This shows how flatMap can be implemented from tailRecM and map, and implicitly suggests that the compiler will not do such a thing automatically: flatMap is not overridden or rewritten by tailRecM at compile time. It's up to the user of the Monad to decide when it makes sense to use tailRecM instead of flatMap.
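To make the law concrete, here is a minimal sketch that runs without Cats. It copies the question's Tree, flatMap and non-tail-recursive tailRecM into a plain object, derives flatMap from tailRecM by hand the same way flatMapFromTailRecMConsistency does, and checks that both give the same tree. TreeLawCheck and flatMapViaTailRecM are hypothetical names used only for illustration. Note that the derived version is only used because it is called explicitly.

```scala
sealed trait Tree[+A]
final case class Branch[A](left: Tree[A], right: Tree[A]) extends Tree[A]
final case class Leaf[A](value: A) extends Tree[A]

object TreeLawCheck {
  // Direct flatMap, as in the question
  def flatMap[A, B](t: Tree[A])(f: A => Tree[B]): Tree[B] = t match {
    case Branch(l, r) => Branch(flatMap(l)(f), flatMap(r)(f))
    case Leaf(a)      => f(a)
  }

  // The question's (non-tail-recursive) tailRecM, with Leaf in place of pure
  def tailRecM[A, B](a: A)(f: A => Tree[Either[A, B]]): Tree[B] = f(a) match {
    case Branch(l, r) =>
      Branch(
        flatMap(l) {
          case Right(b) => Leaf(b)
          case Left(a2) => tailRecM(a2)(f)
        },
        flatMap(r) {
          case Right(b) => Leaf(b)
          case Left(a2) => tailRecM(a2)(f)
        }
      )
    case Leaf(Left(a2)) => tailRecM(a2)(f)
    case Leaf(Right(b)) => Leaf(b)
  }

  def map[A, B](t: Tree[A])(f: A => B): Tree[B] = flatMap(t)(a => Leaf(f(a)))

  // flatMap rebuilt from tailRecM + map, mirroring the cats-laws test:
  // state None = "start from fa", state Some(a) = "continue with fn(a)"
  def flatMapViaTailRecM[A, B](fa: Tree[A])(fn: A => Tree[B]): Tree[B] =
    tailRecM[Option[A], B](Option.empty[A]) {
      case None    => map(fa)(a => Left(Some(a)))
      case Some(a) => map(fn(a))(b => Right(b))
    }

  def main(args: Array[String]): Unit = {
    val tree: Tree[Int] = Branch(Leaf(1), Branch(Leaf(2), Leaf(3)))
    val fn: Int => Tree[String] = n => Branch(Leaf(s"$n-a"), Leaf(s"$n-b"))
    // Both sides of the law should produce structurally equal trees
    println(flatMap(tree)(fn) == flatMapViaTailRecM(tree)(fn))
  }
}
```

The Option[A] state acts as a two-step program counter, which is what lets a single tailRecM call express "run fa, then run fn on each result".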

    This blog has nice Scala examples explaining when tailRecM comes in handy. It follows the PureScript article by Phil Freeman, which originally introduced the method.

    It explains the downsides of using flatMap for monadic composition:

    This characteristic of Scala limits the usefulness of monadic composition where flatMap can call monadic function f, which then can call flatMap etc..

    In contrast with a tailRecM-based implementation:

    This guarantees greater safety on the user of FlatMap typeclass, but it would mean that each of the implementers of the instances would need to provide a safe tailRecM.

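    Many of the methods provided by Cats rely on monadic composition, and loop combinators such as foreverM and iterateWhileM are implemented with tailRecM. So, even if you don't call it directly, implementing a stack-safe tailRecM makes those combinators, and composition with other monads, safe and efficient for your type.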
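As a minimal sketch of that division of labour (the instance author writes one safe tailRecM, and loops are built on top of it), here is a self-contained Option version that does not depend on Cats. OptionLoops and its tailRecM and iterateWhileM are hypothetical standalone functions loosely modeled on the shape of Cats' iterateWhileM; the real Cats instances may be implemented differently.

```scala
import scala.annotation.tailrec

object OptionLoops {
  // Stack-safe tailRecM for Option: keep stepping while f returns Some(Left(a)),
  // stop on Some(Right(b)) or None. @tailrec makes the compiler reject the
  // method if the recursive call were not in tail position.
  @tailrec
  def tailRecM[A, B](a: A)(f: A => Option[Either[A, B]]): Option[B] =
    f(a) match {
      case None           => None
      case Some(Left(a1)) => tailRecM(a1)(f)
      case Some(Right(b)) => Some(b)
    }

  // A loop combinator written only in terms of tailRecM:
  // apply f while p holds, short-circuiting if a step returns None
  def iterateWhileM[A](init: A)(f: A => Option[A])(p: A => Boolean): Option[A] =
    tailRecM[A, A](init) { a =>
      if (p(a)) f(a).map(Left(_)) else Some(Right(a))
    }

  def main(args: Array[String]): Unit = {
    // One million steps without growing the call stack
    println(iterateWhileM(0)(n => Some(n + 1))(_ < 1000000))
    // A failing step (None) stops the whole loop
    println(iterateWhileM(0)(n => if (n < 10) Some(n + 1) else None)(_ < 100))
  }
}
```

Because iterateWhileM never calls flatMap recursively, its stack usage stays constant no matter how many iterations run; all the stack safety lives in the single tailRecM.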

  • Implementation for tree

    In a different answer, @nazarii-bardiuk provides an implementation of tailRecM that is tail recursive but does not pass the flatMap/tailRecM consistency test mentioned above: the tree structure is not properly rebuilt after recursion. Here is a fixed version:

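    import scala.annotation.tailrec

    // Stack-safe replacement for tailRecM inside the Monad[Tree] instance
    def tailRecM[A, B](arg: A)(func: A => Tree[Either[A, B]]): Tree[B] = {
      // toVisit: work stack of subtrees that still have to be expanded.
      // toCollect: results in reverse visiting order; Some(tree) is a finished
      // leaf, None marks a Branch whose two subtrees must be recombined.
      @tailrec
      def loop(toVisit: List[Tree[Either[A, B]]],
               toCollect: List[Option[Tree[B]]]): List[Tree[B]] =
        toVisit match {
          case Branch(l, r) :: next =>
            loop(l :: r :: next, None :: toCollect)

          case Leaf(Left(value)) :: next =>
            loop(func(value) :: next, toCollect)

          case Leaf(Right(value)) :: next =>
            loop(next, Some(pure(value)) :: toCollect)

          case Nil =>
            // Rebuild the tree bottom-up: each None pops two subtrees off acc
            toCollect.foldLeft(List.empty[Tree[B]]) {
              case (acc, Some(tree))             => tree :: acc
              case (left :: right :: tail, None) => Branch(left, right) :: tail
              case (_, None) =>
                throw new IllegalStateException("Branch marker without two subtrees")
            }
        }

      loop(List(func(arg)), Nil).head
    }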
    

    (gist with test)

    You're probably aware, but your example (as well as the answer by @nazarii-bardiuk) is used in the book Scala with Cats by Noel Welsh and Dave Gurnell (highly recommended).