Tags: scala, recursion, stack-overflow, tail-recursion

Scala partial tail recursion


I am defining an interpreter that has to handle a lot of variables, so I have written this:

type Context = Map[String, Int]
abstract class Expr
case class Let(varname: String, varvalue: Expr, body: Expr) extends Expr
case class Var(name: String) extends Expr
case class Plus(a: Expr, b: Expr) extends Expr
case class Num(i: Int) extends Expr

def eval(expr: Expr)(implicit ctx: Context): Int = expr match {
  case Let(i, e, b) => eval(b)(ctx + (i -> eval(e)))
  case Var(s) => ctx(s)
  case Num(i) => i
  case Plus(a, b) => eval(a) + eval(b)
}

For very long expressions of the following form, this fails with a StackOverflowError:

Let("a", Num(1),
Let("b", Plus(Var("a"), Var("a")),
Let("c", Plus(Var("b"), Var("a")),
Let("d", Num(1), ... ))))
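Chains like the one above are easy to generate for testing. The sketch below uses a hypothetical helper, `letChain(n)` (not part of the question), that binds `x0 = 1` and each later `xk = x(k-1) + 1` and returns the last variable, so evaluating `letChain(n)` should give `n`. A large `n` produces the kind of deeply nested `Let` the question describes. The types are the question's, written as top-level script definitions, with `Expr` made a sealed trait.

```scala
// Types from the question (Expr as a sealed trait so matches are checked for exhaustiveness).
type Context = Map[String, Int]
sealed trait Expr
case class Let(varname: String, varvalue: Expr, body: Expr) extends Expr
case class Var(name: String) extends Expr
case class Plus(a: Expr, b: Expr) extends Expr
case class Num(i: Int) extends Expr

// Hypothetical helper: builds
//   Let("x0", Num(1), Let("x1", Plus(Var("x0"), Num(1)), ... Var("x<n-1>")))
// for n >= 1. The expression is assembled from the innermost body outwards
// with a loop, so constructing it needs no recursion even for very large n.
def letChain(n: Int): Expr = {
  require(n >= 1, "n must be positive")
  var body: Expr = Var(s"x${n - 1}")
  for (k <- (n - 1) to 1 by -1)
    body = Let(s"x$k", Plus(Var(s"x${k - 1}"), Num(1)), body)
  Let("x0", Num(1), body)
}
```

Avoid printing or comparing very large chains directly: the compiler-generated `toString`, `equals` and `hashCode` of case classes recurse through the nested fields too.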

However, once the value of a variable has been computed, all that is left is to call the evaluator again on the body of the Let, so it seems to me that this branch should be able to use some kind of partial tail recursion.
How can I achieve partial tail recursion in Scala?


Solution

  • You want some way of getting tail-call optimization on only some of the branches of eval. I don't think this is possible: the most Scala will do is accept a @tailrec annotation on a method as a whole and fail at compile time if any recursive call in that method is not in tail position, i.e. if it can't optimize the method into a loop.

    However, making this iterative to take advantage of the tail call in the Let case is straightforward:

    def eval(expr: Expr, ctx: Context): Int = {
    
      // The expression/context pair we try to reduce at every loop iteration
      var exprM = expr
      var ctxM = ctx
    
      while (true) {
        exprM match { // match on the current expression, not the original argument
          case Var(s) => return ctxM(s)
          case Num(i) => return i
          case Plus(a, b) => return eval(a, ctxM) + eval(b, ctxM)
          case Let(i, e, b) =>
            ctxM = ctxM + (i -> eval(e, ctxM)) // bind the variable in the context
            exprM = b                          // continue with the body instead of recursing
        }
      }
      // Unreachable, but `while (true)` has type Unit, so the block needs an Int-typed ending
      throw new IllegalStateException("unreachable")
    }
    

    Note that this won't solve stack overflows caused by long chains of Plus nodes: there really isn't much this approach can do about those, because the recursive calls in the Plus case are not in tail position.

    There was a time when I thought Scala would add some kind of @tailcall annotation to deal with this sort of thing, but I am not sure there is much interest in that anymore.
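Because @tailrec is checked per method, the iterative idea can also be expressed as a local tail-recursive helper whose only self-call is the Let case; the Plus case calls the outer `eval` instead, so the annotation does not reject it, and the compiler verifies that the Let branch becomes a jump. This is a sketch assuming the question's types (top-level script definitions, `Expr` as a sealed trait); the helper name `loop` is illustrative. Like a hand-written while loop, it removes stack growth for Let chains only.

```scala
import scala.annotation.tailrec

type Context = Map[String, Int]
sealed trait Expr
case class Let(varname: String, varvalue: Expr, body: Expr) extends Expr
case class Var(name: String) extends Expr
case class Plus(a: Expr, b: Expr) extends Expr
case class Num(i: Int) extends Expr

def eval(expr: Expr, ctx: Context): Int = {
  // `loop` is the only method annotated with @tailrec. Its single self-call
  // (the Let case) is in tail position, so scalac compiles it into a jump.
  @tailrec
  def loop(e: Expr, c: Context): Int = e match {
    case Var(s)       => c(s)
    case Num(i)       => i
    // Not a self-call of `loop`: recursion here still uses the JVM stack.
    case Plus(a, b)   => eval(a, c) + eval(b, c)
    // Bind the variable, then continue with the body without a new frame.
    case Let(i, v, b) => loop(b, c + (i -> eval(v, c)))
  }
  loop(expr, ctx)
}
```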
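For long Plus chains, which neither a loop nor @tailrec can flatten, one option not covered above is to trampoline the whole evaluator with the standard library's `scala.util.control.TailCalls`: each step returns a `TailRec` value describing the remaining work, and `.result` runs it in a loop, so pending additions live on the heap rather than the call stack. This sketch assumes a Scala version whose `TailRec` provides `map`/`flatMap` (2.13 and Scala 3 do) and the question's types; `evalT` is a hypothetical name. It trades speed and extra allocations for stack safety.

```scala
import scala.util.control.TailCalls._

type Context = Map[String, Int]
sealed trait Expr
case class Let(varname: String, varvalue: Expr, body: Expr) extends Expr
case class Var(name: String) extends Expr
case class Plus(a: Expr, b: Expr) extends Expr
case class Num(i: Int) extends Expr

// Every recursive step is wrapped in `tailcall`, so evalT returns a description
// of the computation instead of recursing on the JVM stack.
def evalT(expr: Expr, ctx: Context): TailRec[Int] = expr match {
  case Num(i) => done(i)
  case Var(s) => done(ctx(s))
  case Plus(a, b) =>
    for {
      x <- tailcall(evalT(a, ctx))
      y <- tailcall(evalT(b, ctx))
    } yield x + y
  case Let(i, e, b) =>
    tailcall(evalT(e, ctx)).flatMap(v => tailcall(evalT(b, ctx + (i -> v))))
}

// `.result` runs the trampoline iteratively and returns the final value.
def eval(expr: Expr, ctx: Context): Int = evalT(expr, ctx).result
```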