Tags: scala, tail-recursion, tail-call-optimization, continuation-passing, ackermann

How to convert a variation of the Ackermann function to support tail calls?


I'm currently working on a problem that asks me to implement a variation of the Ackermann function in Scala with tail-call optimization, so that the stack does not overflow.

The problem is that I cannot find a way to make it tail-call optimized. I was told that continuation-passing style (CPS) would help, but even though I managed to re-implement it in CPS, I'm still lost.

The variation of the Ackermann function is defined as follows:

ppa(p, a, b) = p(a, b)               (if a <= 0 or b <= 0)
ppa(p, a, b) = p(a, ppa(p, a-1, b))  (if p(a, b) is even and a, b > 0)
ppa(p, a, b) = p(ppa(p, a, b-1), b)  (if p(a, b) is odd and a, b > 0)
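For reference, here is a minimal direct (non-CPS) transcription of the definition above. The name ppaNaive is hypothetical. It makes the core difficulty visible: the recursive result is still passed to p afterwards, so the call is not in tail position, and each step costs one stack frame (at most about a + b of them, since every step decreases a or b by one).

```scala
// Direct transcription of the three-case definition (NOT stack-safe).
def ppaNaive(p: (Int, Int) => Int, a: Int, b: Int): Int =
  if (a <= 0 || b <= 0) p(a, b)                               // base case
  else if (p(a, b) % 2 == 0) p(a, ppaNaive(p, a - 1, b))      // even: recurse on a, then apply p
  else p(ppaNaive(p, a, b - 1), b)                            // odd: recurse on b, then apply p
```

With p = (x, y) => x + y, ppaNaive(p, 2, 2) unfolds as p(2, ppa(1, 2)) = p(2, p(ppa(1, 1), 2)) = p(2, p(p(1, ppa(0, 1)), 2)) = p(2, p(p(1, 1), 2)) = p(2, p(2, 2)) = p(2, 4) = 6. Every pending p(...) is one frame waiting on the stack.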

My implementation without the optimization looks like this:

def ppa(p: (Int, Int) => Int, a: Int, b: Int): Int = {
  def ppa_cont(a: Int, b: Int, ret: (Int, Int) => Int): Int = {
    if (a <= 0 || b <= 0) ret(a, b)
    else (a, b) match {
      case (_, _) if (p(a, b) % 2 == 0) => ret(a, ppa_cont(a-1, b, (x, y) => ret(x, y)))
      case (_, _) => ret(ppa_cont(a, b-1, (x, y) => ret(x, y)), b)
    }
  }

  ppa_cont(a, b, p)
}

Another attempt looks like this:

def ppa(p: (Int, Int) => Int, a: Int, b: Int): Int = {
  def ppa_cont(a: Int, b: Int, cont: (Int, Int) => Int): (Int, Int) => Int = {
    if (a <= 0 || b <= 0) cont
    else if (p(a, b) % 2 == 0) (a, b) => cont(a, ppa_cont(a-1, b, cont)(a-1, b))
    else (a, b) => cont(ppa_cont(a, b-1, cont)(a, b-1), b)
  }
 
  ppa_cont(a, b, p)(a, b)
}

I then tried to make it tail-call optimized like this:

def ppa(p: (Int, Int) => Int, a: Int, b: Int): Int = {
  @annotation.tailrec
  def ppa_cont(a: Int, b: Int, ret: (Int, Int) => TailRec[Int]): TailRec[Int] = {
    if (a <= 0 || b <= 0) tailcall(ret(a, b))
    else (a, b) match {
      case (_, _) if (p(a, b) % 2 == 0) => {
        tailcall(ret(a, ppa_cont(a-1, b, (x, y) => ret(x-1, y))))
      }
      case (_, _) => {
        tailcall(ret(ppa_cont(a, b-1, (x, y) => ret(x, y-1)), b))
      }
    }
  }

  val lifted: (Int, Int) => TailRec[Int] = (x, y) => done(p(x, y))

  ppa_cont(a, b, lifted).result
}

But this won't compile because of type mismatches.

What could be the problem? Am I going in the wrong direction? Any hints or help would be appreciated. Thanks :)

P.S. I got the hint from the question "why scala doesn't make tail call optimization?"
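For comparison, here is a minimal sketch of a CPS version that does compile, assuming scala.util.control.TailCalls as in the third attempt. In that attempt, ppa_cont returns a TailRec[Int], but that result is passed straight into ret, whose parameters are Int; that is one source of the type mismatch. In the sketch below (function names ppaCps and go are hypothetical), the continuation k receives the value of ppa(a, b), so its type is Int => TailRec[Int] rather than (Int, Int) => .... There is no @tailrec annotation: stack safety comes from the TailRec trampoline, and the chain of pending continuations lives on the heap.

```scala
import scala.util.control.TailCalls._

def ppaCps(p: (Int, Int) => Int, a: Int, b: Int): Int = {
  // k is the continuation: it receives the value of ppa(a, b) and finishes the computation.
  def go(a: Int, b: Int, k: Int => TailRec[Int]): TailRec[Int] =
    if (a <= 0 || b <= 0) tailcall(k(p(a, b)))                    // base case: hand p(a, b) to k
    else if (p(a, b) % 2 == 0)
      tailcall(go(a - 1, b, r => tailcall(k(p(a, r)))))           // r = ppa(a - 1, b); then p(a, r)
    else
      tailcall(go(a, b - 1, r => tailcall(k(p(r, b)))))           // r = ppa(a, b - 1); then p(r, b)

  // The initial continuation just returns the final value; .result runs the trampoline loop.
  go(a, b, r => done(r)).result
}
```

Because every call is wrapped in tailcall, go returns immediately and .result drives the evaluation in a loop, so inputs such as ppaCps((x, y) => y, 1000000, 2) run without a StackOverflowError.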


Solution

  • Try cats.free.Trampoline or scala.util.control.TailCalls.TailRec. The result is not @tailrec, but it is stack-safe:

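    import scala.util.control.TailCalls._
    
    def ppa(p: (Int, Int) => Int, a: Int, b: Int): Int = {
      // hlp describes the computation as a TailRec value instead of running it on the call stack
      def hlp(a: Int, b: Int): TailRec[Int] = {
        if (a <= 0 || b <= 0) done(p(a, b))                             // base case
        else if (p(a, b) % 2 == 0) tailcall(hlp(a - 1, b)).map(p(a, _)) // p(a, ppa(a - 1, b))
        else tailcall(hlp(a, b - 1)).map(p(_, b))                       // p(ppa(a, b - 1), b)
      }
    
      // .result evaluates the trampoline in a loop, so deep recursion does not overflow the stack
      hlp(a, b).result
    }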
    

    http://eed3si9n.com/herding-cats/stackless-scala-with-free-monads.html

    http://eed3si9n.com/herding-cats/tail-recursive-monads.html

    Actually, your function doesn't look like Ackermann. The actual Ackermann function makes two nested recursive calls:

    f(m, n) = f(m - 1, f(m, n - 1))
    

    Your function makes only one recursive call per step. It's not hard to write an iterative version of it (tail recursion is usually wanted precisely because the compiler can turn it into an iterative loop automatically). Suppose we have already calculated ppa(i, j) for 0 <= i <= a - 1, 0 <= j <= b - 1 (the already-computed rectangle). Then we calculate the two edge segments (a, 0), (a, 1), ..., (a, b - 1) (in this order) and (0, b), (1, b), ..., (a - 1, b) (in this order). Finally we calculate the corner cell (a, b). This works because each cell depends only on its neighbour (i - 1, j) or (i, j - 1), and that neighbour is always computed before it in this order.

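The bottom-up idea in the last paragraph can be sketched roughly as follows. As a simplifying assumption, this version fills the table in plain row-major order instead of the edge-by-edge order described above. That is still valid for the same reason: cell (i, j) depends only on (i - 1, j) or (i, j - 1), and row-major order computes both before (i, j). It also keeps a single row in memory. The name ppaIterative is hypothetical. Note the trade-off: this evaluates every cell (O(a * b) work), whereas the recursive versions follow a single path of at most about a + b steps.

```scala
def ppaIterative(p: (Int, Int) => Int, a: Int, b: Int): Int =
  if (a <= 0 || b <= 0) p(a, b) // base case, same as the definition
  else {
    // row(j) holds ppa(i, j) for the row i currently being filled.
    val row = Array.tabulate(b + 1)(j => p(0, j)) // row i = 0: every cell is a base case
    for (i <- 1 to a) {
      row(0) = p(i, 0) // column j = 0 is also a base case
      for (j <- 1 to b) {
        row(j) =
          if (p(i, j) % 2 == 0) p(i, row(j)) // row(j) still holds ppa(i - 1, j)
          else p(row(j - 1), j)              // row(j - 1) already holds ppa(i, j - 1)
      }
    }
    row(b)
  }
```

Since it is a plain loop, stack depth is constant regardless of a and b; only the row array grows with b.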