Tags: java, kotlin, tail-recursion

Why does adding tailrec make an incorrect Kotlin corecursion work?


I came across this interesting problem while reading the book The Joy of Kotlin. In chapter 4, while explaining tail recursion, the author provides the following implementation for adding two numbers.

tailrec fun add(a: Int, b: Int): Int = if (b == 0) a else add(inc(a), dec(b))

# where
fun inc(n: Int) = n + 1
fun dec(n: Int) = n - 1

What's interesting about this function is that add(3, -3) returns 0, but it runs into a stack overflow if the tailrec keyword is removed. How can it return the correct answer when the implementation seems incomplete?

I decompiled the Java bytecode to see how the tail call elimination is done, and this is what I saw:

public static final int add(int a, int b) {
    while (b != 0) {
        a = inc(a);
        b = dec(b);
    }
    return a;
}

If I mentally walk through the code, the loop (or the original recursive call) should never terminate, because b starts out negative and is only ever decremented, so it should never reach zero. However, running the Kotlin code above or the decompiled Java code gives the correct result. When I run the same code in a debugger, it appears to loop forever, just as my mental walkthrough predicts. How can this code give the correct result when run normally but appear to loop forever in debug mode?

I personally think the correct implementation should be as below, but I cannot explain why the first one also gives the correct result.

tailrec fun add(a: Int, b: Int): Int =
    if (b == 0) a
    else if (b > 0) add(inc(a), dec(b))
    else add(dec(a), inc(b))

Edit:

@broot's answer is correct. I verified it with the following code:

tailrec fun add(a: Long, b: Int): Long =
    when {
        (b == 0) -> a
        (b > 0) -> add(inc(a), dec(b))
        else -> add(dec(a), inc(b))
    }

fun inc(n: Long) = n + 1
fun dec(n: Long): Long = n - 1

fun inc(n: Int) = n + 1
fun dec(n: Int) = n - 1

fun main() {
    println(add(3, -4))
    println(add(4, -3))
}
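
To make it visible that this sign-aware version does not depend on integer wraparound, a step counter can be threaded through the recursion. This is only a sketch: the name addCounting and the extra steps parameter are additions for illustration and do not appear in the book.

```kotlin
// Sign-aware addition that also reports how many recursive steps it took.
// `steps` is a hypothetical extra accumulator added purely for illustration.
tailrec fun addCounting(a: Long, b: Int, steps: Long = 0): Pair<Long, Long> =
    when {
        b == 0 -> a to steps                          // done: result and step count
        b > 0 -> addCounting(a + 1, b - 1, steps + 1) // move b down towards 0
        else -> addCounting(a - 1, b + 1, steps + 1)  // move b up towards 0
    }

fun main() {
    println(addCounting(3, -4)) // (-1, 4)
    println(addCounting(4, -3)) // (1, 3)
}
```

The step count equals the absolute value of b, so the function finishes after a handful of iterations instead of billions.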

Solution

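  • b wraps around from Int.MIN_VALUE to Int.MAX_VALUE and then counts down to 0. At the same time, a goes the opposite way, wrapping from Int.MAX_VALUE to Int.MIN_VALUE. Because it takes 2^32 - 3 iterations to get b from -3 to 0, a is incremented 2^32 - 3 times, which in 32-bit arithmetic is the same as decreasing it by 3. With such a huge number of iterations, it can't work without tailrec: the recursion overflows the stack long before b reaches 0.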

    We can easily verify this by changing the a side to use Long:

    tailrec fun add(a: Long, b: Int): Long = if (b == 0) a else add(inc(a), dec(b))
    
    fun inc(n: Long) = n + 1
    fun dec(n: Int) = n - 1
    

    In this case, add(3, -3) returns 4294967296, which is 2^32: a was incremented 2^32 - 3 times, and a Long does not wrap at 32 bits.

    Even though it calculates the value correctly, it is probably not a good idea to rely on this. It works pretty much by accident.
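
The wraparound arithmetic described above can be checked without actually running four billion iterations. The following is a minimal sketch; the helpers stepsToZero and wrappedSum are hypothetical names introduced for illustration and are not part of the book's code.

```kotlin
fun dec(n: Int) = n - 1

// Hypothetical helper: number of decrements needed to bring a negative Int b
// to 0 by wrapping around through Int.MIN_VALUE, i.e. 2^32 - |b|.
fun stepsToZero(b: Int): Long {
    require(b < 0) { "only the negative case relies on wraparound" }
    return (1L shl 32) + b
}

// Hypothetical helper: what an Int accumulator holds after that many
// increments. toInt() keeps only the low 32 bits, which models Int overflow.
fun wrappedSum(a: Int, b: Int): Int = (a.toLong() + stepsToZero(b)).toInt()

fun main() {
    println(dec(Int.MIN_VALUE) == Int.MAX_VALUE) // true
    println(stepsToZero(-3))                     // 4294967293
    println(3L + stepsToZero(-3))                // 4294967296
    println(wrappedSum(3, -3))                   // 0
}
```

The third line reproduces the Long result from the answer, and the fourth shows the same sum truncated to 32 bits, which is what the Int version returns.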