Tags: algorithm, scala, recursion, tail-recursion

Pascal's triangle in Scala: computing elements with a tail-recursive approach


In Pascal's triangle the numbers at the edges of the triangle are all 1, and each number inside the triangle is the sum of the two numbers directly above it. A sample Pascal's triangle looks like this:

    1
   1 1
  1 2 1
 1 3 3 1
1 4 6 4 1
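The rule above can be written as a small function that builds each row from the previous one. This is a minimal sketch; the names PascalRows, nextRow and rows are illustrative and do not come from the question.

```scala
object PascalRows {
  // Next row: prepend a 0 and append a 0 to the previous row, then add the two
  // shifted copies element-wise, so each inner entry is the sum of the two above it
  // and both edges stay 1.
  def nextRow(row: List[Int]): List[Int] =
    (0 :: row).zip(row :+ 0).map { case (a, b) => a + b }

  // The first n rows of the triangle, starting from the single-element row List(1)
  def rows(n: Int): List[List[Int]] = List.iterate(List(1), n)(nextRow)

  def main(args: Array[String]): Unit =
    rows(5).foreach(row => println(row.mkString(" ")))
}
```

Running main prints the five rows of the sample above, without the centering.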

I wrote a program that computes the elements of Pascal's triangle using the following technique:

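/**
 * Returns the element at column c of row r (both 0-based).
 * This version is not tail recursive. Can it be made tail recursive?
 *
 * @param c column
 * @param r row
 * @return the value at column c, row r of Pascal's triangle
 */
def pascalTriangle(c: Int, r: Int): Int = {
  if (c == 0 || c == r) 1
  else
    pascalTriangle(c - 1, r - 1) + pascalTriangle(c, r - 1)
}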
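Part of what makes this version hard to turn into a tail call is that every call spawns two more, and the result is assembled by additions that must wait for both to return. The sketch below is an instrumented copy of the same recursion (PascalCallCount and pascalWithCount are illustrative names) that counts the calls. Since every leaf of the call tree returns 1, a cell with value v has v leaves and therefore 2v - 1 calls in total.

```scala
object PascalCallCount {
  // The question's recursion, unchanged, plus a counter of how many times `go` is entered.
  // Returns (value of the cell, number of calls made).
  def pascalWithCount(c: Int, r: Int): (Int, Long) = {
    var calls = 0L
    def go(c: Int, r: Int): Int = {
      calls += 1
      if (c == 0 || c == r) 1
      else go(c - 1, r - 1) + go(c, r - 1)
    }
    val value = go(c, r)
    (value, calls)
  }

  def main(args: Array[String]): Unit =
    println(pascalWithCount(10, 20)) // (184756,369511)
}
```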

For example:

Input:  pascalTriangle(0, 2)
Output: 1

Input:  pascalTriangle(1, 3)
Output: 3

The program above is correct and produces the expected output. My question is: is it possible to write a tail-recursive version of this algorithm, and if so, how?
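One way to do it, sketched here separately from the answers below, is to move the pending work from the call stack into an explicit list. Because every branch of the recursion ends at an edge cell worth 1, the result is simply the number of edge cells reached, so a single Int accumulator is enough. The loop is checked by @tailrec, but it still does the same exponential amount of work as the original. PascalWorklist is an illustrative name, and the require assumes 0 <= c <= r, as the original does implicitly.

```scala
import scala.annotation.tailrec

object PascalWorklist {
  def pascalTriangle(c: Int, r: Int): Int = {
    require(0 <= c && c <= r, s"need 0 <= c <= r, got c=$c, r=$r")

    // pending: cells still to be expanded (an explicit stack instead of the call stack)
    // acc: number of edge cells reached so far; each one contributes 1 to the result
    @tailrec
    def loop(pending: List[(Int, Int)], acc: Int): Int = pending match {
      case Nil => acc
      case (c0, r0) :: rest =>
        if (c0 == 0 || c0 == r0) loop(rest, acc + 1)              // edge cell: worth 1
        else loop((c0 - 1, r0 - 1) :: (c0, r0 - 1) :: rest, acc) // push the two cells above
    }

    loop(List((c, r)), 0)
  }

  def main(args: Array[String]): Unit =
    println(pascalTriangle(1, 3)) // 3
}
```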


Solution

  • Try building the triangle row by row in a tail-recursive loop, keeping only the previous row:

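    import scala.annotation.tailrec

    def pascalTriangle(c: Int, r: Int): Int = {
      require(0 <= c && c <= r, s"need 0 <= c <= r, got c=$c, r=$r")

      // pred holds row r0 - 1; cur is row r0, filled left to right, one column per call
      @tailrec
      def loop(c0: Int, r0: Int, pred: Array[Int], cur: Array[Int]): Int = {
        // Sum of the (up to) two entries above; a missing neighbour counts as 0
        cur(c0) = (if (c0 > 0) pred(c0 - 1) else 0) + (if (c0 < r0) pred(c0) else 0)

        if ((c0 == c) && (r0 == r)) cur(c0)               // reached the requested cell
        else if (c0 < r0) loop(c0 + 1, r0, pred, cur)     // next column in the same row
        else loop(0, r0 + 1, cur, new Array[Int](r0 + 2)) // row done: start the next one
      }

      if ((c == 0) && (r == 0)) 1
      else loop(0, 1, Array(1), Array(0, 0))
    }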
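The array version computes every row up to r, roughly r²/2 additions. When only a single element is needed, a different technique (not used in the answers here) is the multiplicative formula for binomial coefficients, C(r, i) = C(r, i - 1) * (r - i + 1) / i, which is naturally a tail-recursive loop. The division is always exact because C(r, i - 1) * (r - i + 1) equals i * C(r, i). This sketch assumes the result fits in an Int (intermediate products use Long); PascalBinomial is an illustrative name.

```scala
import scala.annotation.tailrec

object PascalBinomial {
  def pascalTriangle(c: Int, r: Int): Int = {
    require(0 <= c && c <= r, s"need 0 <= c <= r, got c=$c, r=$r")
    val k = math.min(c, r - c) // symmetry: C(r, c) == C(r, r - c), so fewer steps

    // Invariant: on entry acc == C(r, i - 1); the division below is exact.
    @tailrec
    def loop(i: Int, acc: Long): Long =
      if (i > k) acc
      else loop(i + 1, acc * (r - i + 1) / i)

    loop(1, 1L).toInt
  }

  def main(args: Array[String]): Unit =
    println(pascalTriangle(2, 4)) // 6
}
```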
    

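    or, using the standard library's scala.util.control.TailCalls, which trampolines the calls so that deep recursion does not overflow the stack: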

    import scala.util.control.TailCalls._
    
    def pascalTriangle(c: Int, r: Int): Int = {
      def hlp(c: Int, r: Int): TailRec[Int] =
        if (c == 0 || (c == r)) done(1)
        else for {
          x <- tailcall(hlp(c - 1, r - 1))
          y <- tailcall(hlp(c, r - 1))
        } yield (x + y)
    
      hlp(c, r).result
    }
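TailCalls does not reduce the amount of work (cells deep inside the triangle still take exponentially many calls), but it takes the recursion off the JVM call stack. A cell near the edge of a very deep row shows the difference: column 1 of row 100000 equals 100000 and is reached through roughly 100000 nested calls, which the plain recursive version would likely overflow on with a default stack size. The TailCalls function is repeated so the snippet compiles on its own; TailCallsDepth is an illustrative name.

```scala
import scala.util.control.TailCalls._

object TailCallsDepth {
  // Same as the TailCalls answer: each recursive step is wrapped in tailcall,
  // and .result runs the resulting chain iteratively instead of on the stack.
  def pascalTriangle(c: Int, r: Int): Int = {
    def hlp(c: Int, r: Int): TailRec[Int] =
      if (c == 0 || c == r) done(1)
      else for {
        x <- tailcall(hlp(c - 1, r - 1))
        y <- tailcall(hlp(c, r - 1))
      } yield x + y

    hlp(c, r).result
  }

  def main(args: Array[String]): Unit =
    println(pascalTriangle(1, 100000)) // 100000
}
```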
    

    or, using the Trampoline type from the Cats library:

    import cats.free.Trampoline
    import cats.free.Trampoline.{defer, done}
    import cats.instances.function._
    
    def pascalTriangle(c: Int, r: Int): Int = {
      def hlp(c: Int, r: Int): Trampoline[Int] =
        if (c == 0 || (c == r)) done(1)
        else for {
          x <- defer(hlp(c - 1, r - 1))
          y <- defer(hlp(c, r - 1))
        } yield (x + y)
    
      hlp(c, r).run
    }
    

    For background on trampolines and free monads in Scala, see: http://eed3si9n.com/herding-cats/stackless-scala-with-free-monads.html