Tags: scala, recursion, tail-recursion, tail-call-optimization

Getting results for subproblems using tail recursion in Scala


I am trying to calculate a result for each subproblem in a @tailrec method, the same way an ordinary recursive solution produces a result for each subproblem. Here is the example I worked on.

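import scala.annotation.tailrec

@tailrec
def collatz(
    n: BigInt,
    acc: BigInt,                   // steps taken so far, counted from the starting number
    fn: (BigInt, BigInt) => Unit   // callback invoked with (current number, acc)
): BigInt = {
  fn(n, acc)
  if (n == 1) {
    acc
  } else if (n % 2 == 0) {
    collatz(n / 2, acc + 1, fn)
  } else {
    collatz(3 * n + 1, acc + 1, fn)
  }
}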

Here I am counting the number of steps a number takes to reach 1 under the Collatz conjecture rules. As an example, take n = 32:

val n = BigInt("32")
val c = collatz(n, 0, (num, acc) => {
  println(s"Num -> $num  Acc -> $acc")
})

I am getting the following output.

Num -> 32  Acc -> 0
Num -> 16  Acc -> 1
Num -> 8  Acc -> 2
Num -> 4  Acc -> 3
Num -> 2  Acc -> 4
Num -> 1  Acc -> 5

An ordinary (non-tail) recursive solution returns the exact count for each number; for instance, 2 reaches 1 in 1 step. So every subproblem gets its exact answer, whereas in the tail-recursive method only the final result is correct. The variable acc behaves exactly like a loop variable, as expected: it counts the steps taken since the starting number, not the steps remaining until 1.
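Because acc counts steps already taken, the per-number answer can be recovered as total minus acc once the total is known. The sketch below keeps the tail-recursive walk, buffers the (number, acc) pairs in a List, and reports total - acc afterwards. It assumes holding the whole sequence in memory is acceptable; the names `RemainingSteps`, `walk`, `collatzSteps` and `visit` are hypothetical, not from the question.

```scala
import scala.annotation.tailrec

object RemainingSteps {
  // Tail-recursive pass: records (number, steps taken so far).
  // Pairs are prepended, so the head of `seen` is the last number visited (1).
  @tailrec
  def walk(n: BigInt, acc: BigInt, seen: List[(BigInt, BigInt)]): (BigInt, List[(BigInt, BigInt)]) = {
    val seen2 = (n, acc) :: seen
    if (n == 1) (acc, seen2)
    else if (n % 2 == 0) walk(n / 2, acc + 1, seen2)
    else walk(n * 3 + 1, acc + 1, seen2)
  }

  // Hypothetical wrapper: once the total is known, steps remaining = total - acc.
  // Since `seen` starts at 1, callbacks arrive in the same order as collatz2
  // (1 first, the starting number last).
  def collatzSteps(n: BigInt, visit: (BigInt, BigInt) => Unit): BigInt = {
    val (total, seen) = walk(n, 0, Nil)
    seen.foreach { case (num, acc) => visit(num, total - acc) }
    total
  }

  def main(args: Array[String]): Unit = {
    collatzSteps(BigInt(32), (num, steps) => println(s"Num -> $num  Acc -> $steps"))
  }
}
```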

How can I change the code so that it is tail-call optimized and still gives the exact value for each subproblem? In other words, how can I get stack-like behavior for the acc variable?

Also, a related question: how large is the overhead of the lambda fn for large values of n, assuming the println statement is not used?

Here is a (non-tail) recursive solution that produces the correct result for each subproblem.

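def collatz2(
    n: BigInt,
    fn: (BigInt, BigInt) => Unit
): BigInt = {
  // Not tail recursive: "+ 1" and fn(n, c) run after the recursive call returns,
  // so c is the number of steps from n down to 1.
  val c: BigInt = if (n == 1) {
    0
  } else if (n % 2 == 0) {
    collatz2(n / 2, fn) + 1
  } else {
    collatz2(3 * n + 1, fn) + 1
  }
  fn(n, c)
  c
}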

It produces the following output.

Num -> 1  Acc -> 0
Num -> 2  Acc -> 1
Num -> 4  Acc -> 2
Num -> 8  Acc -> 3
Num -> 16  Acc -> 4
Num -> 32  Acc -> 5

Solution

    I'm not sure I understood your question correctly, but it sounds like you are asking us to rewrite collatz2 so that it is tail recursive. I have done so in two ways.

    Although I have provided two solutions, they are really the same thing. One uses a List as a stack, where the head of the List is the top of the stack; the other uses the mutable.Stack data structure. Study the two until you can see why both are equivalent to collatz2 from the original question.
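    For readers who have not used an immutable List as a stack: prepending with `::` is the push, and matching `top :: rest` is the pop. The small sketch below (the object `ListAsStack` and its methods are hypothetical illustrations) shows that the last value pushed comes off first, and that the pop loop itself can be tail recursive.

```scala
import scala.annotation.tailrec

object ListAsStack {
  // Push: prepend with ::, so the most recently pushed element is the head (top).
  def push(stack: List[Int], x: Int): List[Int] = x :: stack

  // Pop everything, pairing each element with the number of pops before it.
  // Tail recursive: the recursive call is the last expression evaluated.
  @tailrec
  def popAll(stack: List[Int], ctr: Int, out: Vector[(Int, Int)]): Vector[(Int, Int)] =
    stack match {
      case top :: rest => popAll(rest, ctr + 1, out :+ (top -> ctr))
      case Nil         => out
    }

  def main(args: Array[String]): Unit = {
    // Push the Collatz terms of 32 in the order they are generated.
    val stack = List(32, 16, 8, 4, 2, 1).foldLeft(List.empty[Int])(push)
    println(stack)                          // List(1, 2, 4, 8, 16, 32)
    println(popAll(stack, 0, Vector.empty)) // Vector((1,0), (2,1), (4,2), (8,3), (16,4), (32,5))
  }
}
```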

    To make the program tail recursive, we simulate what the call stack does in collatz2: first push every term of the sequence onto an explicit stack, then pop the terms off one by one. The value of Acc is assigned during the pop phase. (For those who don't remember, Acc in Hariharan's terms is the index of each term counted backwards from 1, which is the number of steps that term needs to reach 1.)

    import scala.annotation.tailrec
    import scala.collection.mutable
    
    object CollatzCount {
    
      def main(args: Array[String]): Unit = {
        val start = 32
    
        collatzFinalList(start, printer)
    
        collatzFinalStack(start, printer)
      }
    
      // Push phase (List version): prepend each term, so 1 ends up at the head (top).
      @tailrec
      def collatzInnerList(n: Int, acc: List[Int]): List[Int] = {
        if (n == 1) n :: acc
        else if (n % 2 == 0) collatzInnerList(n / 2, n :: acc)
        else collatzInnerList(3 * n + 1, n :: acc)
      }
    
      // Pop phase (List version): walk from the head, passing each term with its counter.
      def collatzFinalList(n: Int, fun: (Int, Int) => Unit): Unit = {
        val acc = collatzInnerList(n, List())
        acc.foldLeft(0) { (ctr, e) =>
          fun(e, ctr)
          ctr + 1
        }
      }
    
      // Push phase (mutable.Stack version): every term is pushed, 1 last.
      @tailrec
      def collatzInnerStack(n: Int, stack: mutable.Stack[Int]): mutable.Stack[Int] = {
        stack.push(n)
        if (n == 1) stack
        else if (n % 2 == 0) collatzInnerStack(n / 2, stack)
        else collatzInnerStack(3 * n + 1, stack)
      }
    
      // Pop phase (mutable.Stack version): pop until empty, counting as we go.
      @tailrec
      def popStack(ctr: Int, stack: mutable.Stack[Int], fun: (Int, Int) => Unit): Unit = {
        if (stack.nonEmpty) {
          val popped = stack.pop()
          fun(popped, ctr)
          popStack(ctr + 1, stack, fun)
        }
      }
    
      def collatzFinalStack(n: Int, fun: (Int, Int) => Unit): Unit = {
        val stack = collatzInnerStack(n, mutable.Stack[Int]())
        popStack(0, stack, fun)
      }
    
      // Note: Int can overflow on 3 * n + 1 for large starts; the question's BigInt avoids this.
      val printer = (x: Int, y: Int) => println("Num -> " + x + "  Acc -> " + y)
    
    }
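    The code above uses Int, while the question uses BigInt and expects the function to return the count, as collatz2 does. Below is a sketch of the same List-based push/pop technique adapted to the question's signature; the names `CollatzBigInt`, `pushAll`, `popAll` and `collatz3` are hypothetical. The explicit stack still needs memory proportional to the sequence length, like collatz2's call stack, but it lives on the heap, so long sequences do not risk a StackOverflowError.

```scala
import scala.annotation.tailrec

object CollatzBigInt {
  // Push phase: prepend each term, so 1 ends up at the head (top of the stack).
  @tailrec
  def pushAll(n: BigInt, stack: List[BigInt]): List[BigInt] =
    if (n == 1) n :: stack
    else if (n % 2 == 0) pushAll(n / 2, n :: stack)
    else pushAll(n * 3 + 1, n :: stack)

  // Pop phase: report each term with its distance from 1, in collatz2's order.
  @tailrec
  def popAll(stack: List[BigInt], ctr: BigInt, fn: (BigInt, BigInt) => Unit): Unit =
    stack match {
      case top :: rest =>
        fn(top, ctr)
        popAll(rest, ctr + 1, fn)
      case Nil => ()
    }

  // Same signature as collatz2: returns the step count of the starting number.
  def collatz3(n: BigInt, fn: (BigInt, BigInt) => Unit): BigInt = {
    val stack = pushAll(n, Nil)
    popAll(stack, 0, fn)
    BigInt(stack.length - 1)
  }

  def main(args: Array[String]): Unit = {
    collatz3(BigInt(32), (num, acc) => println(s"Num -> $num  Acc -> $acc"))
  }
}
```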