
Why can't Option.fold be used tail recursively in Scala?


Below, sumAllIf is tail recursive and sumAllFold is not. However, sumAllIf is effectively the same implementation as sumAllFold, with Option.fold inlined by hand. Is this a shortcoming of the Scala compiler (or of the Scala library), or am I overlooking something?

def maybeNext(in: Int): Option[Int] = if in < 10 then Some(in + 1) else None

// The Scala library implements Option.fold like this:
// @inline final def fold[B](ifEmpty: => B)(f: A => B): B =
//   if (isEmpty) ifEmpty else f(this.get)
@annotation.tailrec
def sumAllIf(current: Int, until: Int, sum: Int): Int =
  val nextOption = maybeNext(current)
  if (nextOption.isEmpty) sum else sumAllIf(nextOption.get, until, sum + nextOption.get)

// However, with Scala 3.1.0 and earlier, this is NOT tail recursive (annotating it with @annotation.tailrec fails to compile):
def sumAllFold(current: Int, until: Int, sum: Int): Int =
  maybeNext(current).fold(sum)(next => sumAllFold(next, until, sum + next))

@main def main(): Unit =
  println(sumAllIf(0, 10, 0))
  println(sumAllFold(0, 10, 0))
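One way to observe the difference at runtime is to record the JVM stack depth at the point where the recursion bottoms out. The sketch below is a hypothetical instrumentation, assuming Scala 3 on the JVM: stackDepth, depthIf, depthFold and compareDepths are illustrative names, not part of the question's code, and the depth comes from Thread.getStackTrace. The @tailrec version is compiled to a loop, so its depth stays small; the fold version keeps additional frames on the stack for every step.

```scala
// Same maybeNext as in the question.
def maybeNext(in: Int): Option[Int] = if in < 10 then Some(in + 1) else None

// Hypothetical helper: number of frames on the current thread's stack.
def stackDepth(): Int = Thread.currentThread().getStackTrace().length

// Mirrors sumAllIf, but returns the stack depth at the base case instead of the sum.
@annotation.tailrec
def depthIf(current: Int): Int =
  val nextOption = maybeNext(current)
  if nextOption.isEmpty then stackDepth() else depthIf(nextOption.get)

// Mirrors sumAllFold: the recursive call sits inside the lambda passed to fold,
// so it cannot be marked @tailrec.
def depthFold(current: Int): Int =
  maybeNext(current).fold(stackDepth())(next => depthFold(next))

@main def compareDepths(): Unit =
  // The fold version should report a clearly larger depth, because each step
  // leaves frames (such as depthFold itself, the lambda body and Option.fold)
  // on the stack, while the @tailrec loop reuses a single frame.
  println(s"if-version depth:   ${depthIf(0)}")
  println(s"fold-version depth: ${depthFold(0)}")
```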

The issue is similar to the question "Scala @tailrec with fold", but here I'd like to find out why this happens and whether it could be supported in the future.

The example is for Scala 3.1, but the issue itself is valid for Scala 2.x as well.


Solution

  • The recursive call happens inside a lambda, and the lambda body is compiled as a separate function, so the call is made from that function rather than being a tail call of sumAllFold itself. It would only count as a tail call if the compiler first inlined fold and the lambda into your method and only then checked whether the result is tail recursive. The compiler does not do that automatically, and it probably never will.

    The good news is that in Scala 3 you can work around this fairly easily, and it is theoretically possible that the standard library will one day be adapted to take advantage of it. All it takes is implementing fold explicitly as an inline method with inline parameters.

    inline def fold[A, B](opt: Option[A])(inline onEmpty: B)(inline f: A => B): B =
      opt match
        case Some(a) => f(a)
        case None => onEmpty
    
    @annotation.tailrec
    def sumAllFold(current: Int, until: Int, sum: Int): Int =
      fold(maybeNext(current))(sum)(next => sumAllFold(next, until, sum + next))
    

    Note that inline parameters automatically have by-name semantics, so onEmpty is only evaluated in the None case, without having to change its type to => B.
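A self-contained sketch of the same workaround, assuming Scala 3: here the inline fold is written as a hypothetical extension method foldInline (not a standard library method), so the call site keeps the familiar opt.fold(...)(...) shape. The fallback counter (also hypothetical) checks that the inline ifEmpty argument is only evaluated for an empty option.

```scala
// Hypothetical helper, not part of the standard library: an inline fold
// written as an extension method so it reads like Option's own fold.
extension [A](opt: Option[A])
  inline def foldInline[B](inline ifEmpty: B)(inline f: A => B): B =
    opt match
      case Some(a) => f(a)
      case None    => ifEmpty

def maybeNext(in: Int): Option[Int] = if in < 10 then Some(in + 1) else None

// After inlining, the lambda body is substituted at the call site, so the
// recursive call ends up in tail position of sumAllFold and @tailrec accepts it.
@annotation.tailrec
def sumAllFold(current: Int, until: Int, sum: Int): Int =
  maybeNext(current).foldInline(sum)(next => sumAllFold(next, until, sum + next))

// Counts how often the ifEmpty argument is actually evaluated.
var evaluations = 0
def fallback(): Int =
  evaluations += 1
  0

@main def demo(): Unit =
  println(sumAllFold(0, 10, 0))                              // 55
  val some = Some(1).foldInline(fallback())(_ + 1)           // ifEmpty is not evaluated
  val none = Option.empty[Int].foldInline(fallback())(_ + 1) // ifEmpty is evaluated once
  println((some, none, evaluations))                         // (2,0,1)
```

Naming it foldInline rather than fold is deliberate: an extension method is only considered when the receiver has no member of that name, so an extension called fold would never be chosen over Option's own fold.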