Tags: scala, recursion, callstack, stack-frame, tail-call-optimization

How to do tail call optimisation in Scala3?


I am trying to write a program that is 100% iterative: its functions never need to return, because nothing has to happen after such a return.

In other words, the program is 100% in tail position. Consider the following toy program:

  def foo(): Unit =
    bar()

  def bar(): Unit =
    foo()

  try
    foo()
  catch
    case s: StackOverflowError =>
      println(" Stack overflow!")

Calling foo indeed results in a stack overflow, which is no surprise: foo calls bar, so bar needs a stack frame; bar then calls foo, which again needs a stack frame, and so on. It is clear why the stack overflow error occurs.

My question is: how can I define foo and bar as they are without getting a stack overflow? Languages like Scheme allow this program. It would run forever, yes, but the stack would not grow, because the implementation knows that nothing needs to happen after calling, e.g., bar from foo, so there is no need to keep foo's stack frame upon the call to bar. Clearly Scala (or is it the JVM?) does keep the stack frame alive.

Now consider the next code example:

  def foo(): Unit = 
    foo()

  foo()

This program will run forever, but a stack overflow will never occur.

I am aware of the @tailrec annotation, but as I understand it, it only applies to a situation like the second example (a method calling itself), not to the first example.
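For reference, a minimal sketch of the case @tailrec does cover, using a hypothetical sumTo helper: the recursive call is a direct call to the method itself in tail position, so the compiler rewrites it into a loop, and the annotation only makes compilation fail if that rewrite is not possible. Annotating foo from the first example would be rejected by the compiler, since its tail call goes to bar rather than to foo itself.

```scala
import scala.annotation.tailrec

// Hypothetical helper: direct self-recursion in tail position.
// The compiler turns the recursive call into a jump back to the start of the
// method, so a million "calls" reuse a single stack frame.
@tailrec
def sumTo(n: Long, acc: Long = 0L): Long =
  if (n == 0) acc
  else sumTo(n - 1, acc + n)

println(sumTo(1000000L)) // 500000500000
```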

Any ideas? (I need the first example to run forever, like the second example, without a stack overflow.)


Solution

  • As you suspect, the JVM forbids non-local jumps (there is no instruction that transfers control to another method while discarding the caller's frame), so if foo and bar are compiled as separate methods (which is generally desirable), tail call elimination is impossible.

    However, you can trampoline: instead of calling each other directly, foo and bar return a value that a driver loop interprets as "call this function next".

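    import scala.annotation.tailrec

    sealed trait TrampolineInstruction[A]

    // Finished: carries the final result
    final case class JumpOff[A](value: A) extends TrampolineInstruction[A]

    // Not finished: a suspended computation that yields the next instruction.
    // Case classes cannot take by-name parameters, so the thunk is an explicit function.
    final case class JumpAgain[A](thunk: () => TrampolineInstruction[A])
      extends TrampolineInstruction[A]

    // The driver loop: @tailrec guarantees it compiles to a jump, so each
    // bounce between foo and bar costs no extra stack frame
    @tailrec
    def runTrampolined[A](ti: TrampolineInstruction[A]): A =
      ti match {
        case JumpOff(value)   => value
        case JumpAgain(thunk) => runTrampolined(thunk())
      }

    // foo and bar no longer call each other; they return "call the other one next"
    def foo(): TrampolineInstruction[Unit] = JumpAgain(() => bar())
    def bar(): TrampolineInstruction[Unit] = JumpAgain(() => foo())

    runTrampolined(foo())  // will not overflow the stack, never completes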
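To check that the trampoline keeps the stack flat even for a long run, here is a self-contained sketch with a terminating pair of mutually recursive functions, the hypothetical isEven and isOdd, which bounce a million times through runTrampolined. The trampoline definitions are repeated so the block runs on its own, with the thunk stored as an explicit () => function because case classes cannot take by-name parameters.

```scala
import scala.annotation.tailrec

sealed trait TrampolineInstruction[A]
final case class JumpOff[A](value: A) extends TrampolineInstruction[A]
final case class JumpAgain[A](thunk: () => TrampolineInstruction[A])
  extends TrampolineInstruction[A]

// Driver loop: runs one step at a time until a final value appears
@tailrec
def runTrampolined[A](ti: TrampolineInstruction[A]): A =
  ti match {
    case JumpOff(value)   => value
    case JumpAgain(thunk) => runTrampolined(thunk())
  }

// Hypothetical terminating variant of foo/bar: each function hands control
// to the other by returning a JumpAgain instead of calling it directly.
def isEven(n: Int): TrampolineInstruction[Boolean] =
  if (n == 0) JumpOff(true) else JumpAgain(() => isOdd(n - 1))

def isOdd(n: Int): TrampolineInstruction[Boolean] =
  if (n == 0) JumpOff(false) else JumpAgain(() => isEven(n - 1))

println(runTrampolined(isEven(1000000))) // true
```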
    

    Cats provides an Eval monad that encapsulates the idea of trampolining. The above definitions of foo and bar then become:

    import cats.Eval
    
    def foo(): Eval[Unit] = Eval.defer(bar())
    def bar(): Eval[Unit] = Eval.defer(foo())
    
    foo().value  // consumes a bounded (very small, not necessarily 1) number of stack frames, never completes
    

    Because Eval is a monad, more complex logic can be composed with map and flatMap, so value only needs to be called once, at the end, rather than in the middle of the chain.
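Cats is an external dependency, so here is a sketch of the same monadic style using scala.util.control.TailCalls from the Scala standard library as a stand-in for Eval (a swapped-in library, not Cats itself). The hypothetical sumList is not tail recursive, because the addition happens only after the recursive result is available; map records that pending step as a heap object, and result is the single place where the whole chain is run.

```scala
import scala.util.control.TailCalls.{TailRec, done, tailcall}

// Not tail recursive: "+ h" must wait for the recursive result.
// tailcall suspends the recursive call, map queues the pending addition,
// and nothing runs until .result drives the chain in a loop.
def sumList(xs: List[Int]): TailRec[Long] = xs match {
  case Nil    => done(0L)
  case h :: t => tailcall(sumList(t)).map(_ + h)
}

// A naive `h + sumList(t)` on a list this long would typically
// overflow a default-sized JVM stack.
val numbers = List.range(1, 1000001)
println(sumList(numbers).result) // 500000500000
```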

    Note: JumpOff in the first snippet is basically Eval.Leaf (generally constructed using Eval.now) and JumpAgain is basically Eval.Defer.
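The standard library's scala.util.control.TailCalls follows the same pattern: done corresponds to JumpOff (Eval.now), tailcall to JumpAgain (Eval.defer), and result is the driver loop. A minimal sketch of the question's foo and bar on top of it, with a hypothetical remaining counter added only so the demo finishes:

```scala
import scala.util.control.TailCalls.{TailRec, done, tailcall}

// done(x)     plays the role of JumpOff / Eval.now
// tailcall(t) plays the role of JumpAgain / Eval.defer (t is taken by name)
// `remaining` is a hypothetical counter so the demo stops; without it,
// foo and bar would bounce forever in constant stack space.
def foo(remaining: Long): TailRec[String] =
  if (remaining == 0) done("stopped in foo")
  else tailcall(bar(remaining - 1))

def bar(remaining: Long): TailRec[String] =
  if (remaining == 0) done("stopped in bar")
  else tailcall(foo(remaining - 1))

println(foo(10000000L).result) // stopped in foo
```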