Why doesn't Scala perform tail-call optimization here?


Just playing with continuations. The goal is to create a function that receives another function and a repetition count as parameters, and returns a function that applies the given function that many times.

The implementation looks pretty obvious:

import scala.annotation.tailrec

def n_times[T](func: T => T, count: Int): T => T = {
  @tailrec
  def n_times_cont(cnt: Int, continuation: T => T): T => T = cnt match {
    case _ if cnt < 1 => throw new IllegalArgumentException(s"count was wrong $count")
    case 1 => continuation
    // Each step wraps the previous continuation in one more closure
    case _ => n_times_cont(cnt - 1, i => continuation(func(i)))
  }
  n_times_cont(count, func)
}

def inc(x: Int) = x + 1

val res1 = n_times(inc, 1000)(1) // Works OK, returns 1001

val res = n_times(inc, 10000000)(1) // FAILS

But there is a problem: this code fails with a StackOverflowError. Why is there no tail-call optimization here?

I'm running it in Eclipse using the Scala plugin, and it fails with:

Exception in thread "main" java.lang.StackOverflowError
    at scala.runtime.BoxesRunTime.boxToInteger(Unknown Source)
    at Task_Mult$$anonfun$1.apply(Task_Mult.scala:25)
    at Task_Mult$$anonfun$n_times_cont$1$1.apply(Task_Mult.scala:18)

P.S. The F# code below, which is an almost direct translation, works without any issues:

let n_times_cnt func count = 
    let rec n_times_impl count' continuation = 
        match count' with
        | _ when count'<1 -> failwith "wrong count"
        | 1 -> continuation
        | _ -> n_times_impl (count'-1) (func >> continuation) 
    n_times_impl count func

let inc x = x+1
let res = (n_times_cnt inc 10000000) 1

printfn "%o" res

Solution

  • The Scala standard library has an implementation of trampolines in scala.util.control.TailCalls. Revisiting your implementation: n_times_cont itself is optimized by @tailrec, but what it builds is a chain of nested closures. When the returned function is applied, each continuation(func(t)) is a tail call into another function object, and neither scalac nor the JVM eliminates those calls, because scalac only rewrites direct self-recursive calls. So let's build up a T => TailRec[T] instead, where the stack frames are replaced with objects on the heap, and then return a function that takes the argument and passes it to the trampolined function:

    import util.control.TailCalls._
    def n_times_trampolined[T](func: T => T, count: Int): T => T = {
      @annotation.tailrec
      def n_times_cont(cnt: Int, continuation: T => TailRec[T]): T => TailRec[T] = cnt match {
        case _ if cnt < 1 => throw new IllegalArgumentException(s"count was wrong $count")
        case 1 => continuation
        case _ => n_times_cont(cnt - 1, t => tailcall(continuation(func(t))))
      }
      val lifted : T => TailRec[T] = t => done(func(t))
      t => n_times_cont(count, lifted)(t).result
    }
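
If the composed function only ever needs to be applied, a simpler route than trampolining may be to skip the chain of continuations entirely and loop at call time. The following is a minimal sketch under that assumption. The names nTimesLoop and go are hypothetical and not part of the answer above or of the standard library.

```scala
import scala.annotation.tailrec

// Hypothetical alternative: no closure chain is built. The returned
// function applies func `count` times in a loop when it is called.
def nTimesLoop[T](func: T => T, count: Int): T => T = {
  // require throws IllegalArgumentException, like the original check
  require(count >= 1, s"count was wrong $count")
  t => {
    @tailrec
    def go(remaining: Int, acc: T): T =
      if (remaining == 0) acc
      else go(remaining - 1, func(acc)) // direct self-call, so scalac turns it into a loop
    go(count, t)
  }
}
```

With this version, nTimesLoop[Int](_ + 1, 10000000)(1) evaluates to 10000001 without deep recursion. The trade-off is that the continuation-passing structure from the question is gone, so this only fits when that structure was a means rather than the goal.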