Kotlin suspend function recursive call


I suddenly discovered that a recursive call to a suspend function takes more time than calling the same function without the suspend modifier. Please consider the code snippet below (a basic Fibonacci series calculation):

suspend fun asyncFibonacci(n: Int): Long = when {
    n <= -2 -> asyncFibonacci(n + 2) - asyncFibonacci(n + 1)
    n == -1 -> 1
    n == 0 -> 0
    n == 1 -> 1
    n >= 2 -> asyncFibonacci(n - 1) + asyncFibonacci(n - 2)
    else -> throw IllegalArgumentException()
}

If I call this function and measure its execution time with the code below:

fun main(args: Array<String>) {
    val totalElapsedTime = measureTimeMillis {
        val nFibonacci = 40

        val deferredFirstResult: Deferred<Long> = async {
            asyncProfile("fibonacci") { asyncFibonacci(nFibonacci) } as Long
        }
        val deferredSecondResult: Deferred<Long> = async {
            asyncProfile("fibonacci") { asyncFibonacci(nFibonacci) } as Long
        }

        val firstResult: Long = runBlocking { deferredFirstResult.await() }
        val secondResult: Long = runBlocking { deferredSecondResult.await() }
        val superSum = secondResult + firstResult
        println("${thread()} - Sum of two $nFibonacci'th fibonacci numbers: $superSum")
    }
    println("${thread()} - Total elapsed time: $totalElapsedTime millis")
}

I observe the following results:

commonPool-worker-2:fibonacci - Start calculation...
commonPool-worker-1:fibonacci - Start calculation...
commonPool-worker-2:fibonacci - Finish calculation...
commonPool-worker-2:fibonacci - Elapsed time: 7704 millis
commonPool-worker-1:fibonacci - Finish calculation...
commonPool-worker-1:fibonacci - Elapsed time: 7741 millis
main - Sum of two 40'th fibonacci numbers: 204668310
main - Total elapsed time: 7816 millis

But if I remove the suspend modifier from the asyncFibonacci function, I get this result:

commonPool-worker-2:fibonacci - Start calculation...
commonPool-worker-1:fibonacci - Start calculation...
commonPool-worker-1:fibonacci - Finish calculation...
commonPool-worker-1:fibonacci - Elapsed time: 1179 millis
commonPool-worker-2:fibonacci - Finish calculation...
commonPool-worker-2:fibonacci - Elapsed time: 1201 millis
main - Sum of two 40'th fibonacci numbers: 204668310
main - Total elapsed time: 1250 millis

I know it's better to rewrite such a function with tailrec, which would make it roughly 100 times faster, but anyway: what does the suspend keyword do that slows execution down from 1 second to 8 seconds?

Is it a totally stupid idea to mark recursive functions with suspend?
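For reference, the tailrec rewrite mentioned above might look like the following minimal sketch. The name `tailrecFibonacci` and its accumulator parameters are hypothetical, not from the question, and it handles only n >= 0. It replaces the exponential double recursion with a single self-call in tail position, which the compiler turns into a loop.

```kotlin
// Hypothetical tail-recursive Fibonacci: `current` holds F(k) and `next` holds F(k + 1).
// Each step shifts the pair forward by one; the compiler rewrites the tail call into a loop.
tailrec fun tailrecFibonacci(n: Int, current: Long = 0, next: Long = 1): Long {
    require(n >= 0) { "Negative input is not supported in this sketch" }
    return if (n == 0) current else tailrecFibonacci(n - 1, next, current + next)
}

fun main() {
    // Sum of two 40th Fibonacci numbers, as in the question's benchmark.
    println(tailrecFibonacci(40) * 2) // prints 204668310
}
```

Because this version makes only n calls instead of an exponential number, any per-call overhead (including that of a suspend modifier) matters far less.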


Solution

  • As an introductory comment, your testing code setup is too complex. This much simpler code stresses suspend fun recursion just as well:

    fun main(args: Array<String>) {
        launch(Unconfined) {
            val nFibonacci = 37
            var sum = 0L
            (1..1_000).forEach {
                val took = measureTimeMillis {
                    sum += suspendFibonacci(nFibonacci)
                }
                println("Sum is $sum, took $took ms")
            }
        }
    }
    
    suspend fun suspendFibonacci(n: Int): Long {
        return when {
            n >= 2 -> suspendFibonacci(n - 1) + suspendFibonacci(n - 2)
            n == 0 -> 0
            n == 1 -> 1
            else -> throw IllegalArgumentException()
        }
    }
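    `launch(Unconfined)` comes from the pre-1.0 experimental kotlinx.coroutines API; current versions have no top-level `launch`, and the dispatcher is `Dispatchers.Unconfined`. The following is a minimal sketch, assuming only the Kotlin 1.3+ standard library, that times the plain and suspend versions side by side on one thread. `plainFibonacci` and `runSuspend` are hypothetical names, and `runSuspend` works only because `suspendFibonacci` never actually suspends. Absolute timings will depend on the JVM and hardware.

```kotlin
import kotlin.coroutines.Continuation
import kotlin.coroutines.EmptyCoroutineContext
import kotlin.coroutines.startCoroutine
import kotlin.system.measureTimeMillis

// Same recursion as suspendFibonacci, but without the suspend modifier.
fun plainFibonacci(n: Int): Long = when {
    n >= 2 -> plainFibonacci(n - 1) + plainFibonacci(n - 2)
    n == 0 -> 0
    n == 1 -> 1
    else -> throw IllegalArgumentException()
}

suspend fun suspendFibonacci(n: Int): Long = when {
    n >= 2 -> suspendFibonacci(n - 1) + suspendFibonacci(n - 2)
    n == 0 -> 0
    n == 1 -> 1
    else -> throw IllegalArgumentException()
}

// Hypothetical helper: starts a suspend block on the calling thread and returns its result.
// Valid only for code that never suspends, so the completion callback has already
// run by the time startCoroutine returns.
fun <T> runSuspend(block: suspend () -> T): T {
    var outcome: Result<T>? = null
    block.startCoroutine(Continuation<T>(EmptyCoroutineContext) { outcome = it })
    return checkNotNull(outcome) { "block suspended" }.getOrThrow()
}

fun main() {
    val n = 32
    repeat(10) {
        var plain = 0L
        var suspended = 0L
        val plainMs = measureTimeMillis { plain = plainFibonacci(n) }
        val suspendMs = measureTimeMillis { suspended = runSuspend { suspendFibonacci(n) } }
        check(plain == suspended)
        println("plain: $plainMs ms, suspend: $suspendMs ms")
    }
}
```

Repeating the measurement gives the JIT time to warm up before the numbers are compared.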
    

    I tried to reproduce its performance by writing a plain function that approximates the kinds of things the suspend function must do to achieve suspendability:

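    import kotlin.coroutines.Continuation
    import kotlin.coroutines.CoroutineContext

    // Stand-in for the real COROUTINE_SUSPENDED marker in kotlin.coroutines.intrinsics
    val COROUTINE_SUSPENDED = Any()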
    
    fun fakeSuspendFibonacci(n: Int, inCont: Continuation<Unit>): Any? {
        val cont = if (inCont is MyCont && inCont.label and Integer.MIN_VALUE != 0) {
            inCont.label -= Integer.MIN_VALUE
            inCont
        } else MyCont(inCont)
        val suspended = COROUTINE_SUSPENDED
        loop@ while (true) {
            when (cont.label) {
                0 -> {
                    when {
                        n >= 2 -> {
                            cont.n = n
                            cont.label = 1
                            val f1 = fakeSuspendFibonacci(n - 1, cont)!!
                            if (f1 === suspended) {
                                return f1
                            }
                            cont.data = f1
                            continue@loop
                        }
                        n == 1 || n == 0 -> return n.toLong()
                        else -> throw IllegalArgumentException("Negative input not allowed")
                    }
                }
                1 -> {
                    cont.label = 2
                    cont.f1 = cont.data as Long
                    val f2 = fakeSuspendFibonacci(cont.n - 2, cont)!!
                    if (f2 === suspended) {
                        return f2
                    }
                    cont.data = f2
                    continue@loop
                }
                2 -> {
                    val f2 = cont.data as Long
                    return cont.f1 + f2
                }
                else -> throw AssertionError("Invalid continuation label ${cont.label}")
            }
        }
    }
    
    class MyCont(val completion: Continuation<Unit>) : Continuation<Unit> {
        var label = 0
        var data: Any? = null
        var n: Int = 0
        var f1: Long = 0
    
        override val context: CoroutineContext get() = TODO("not implemented")
        // Kotlin 1.3+ Continuation declares a single resumeWith(Result) method;
        // resume/resumeWithException are now extension functions built on it.
        override fun resumeWith(result: Result<Unit>) = TODO("not implemented")
    }
    

    You have to invoke this one with

    sum += fakeSuspendFibonacci(nFibonacci, InitialCont()) as Long
    

    where InitialCont is

    class InitialCont : Continuation<Unit> {
        override val context: CoroutineContext get() = TODO("not implemented")
        // Kotlin 1.3+ Continuation declares a single resumeWith(Result) method
        override fun resumeWith(result: Result<Unit>) = TODO("not implemented")
    }
    

    Basically, to compile a suspend fun, the compiler has to turn its body into a state machine. Each invocation must also create an object to hold the machine's state. When you resume, the state object tells which state handler to go to. The above still isn't all there is to it; the real code is even more complex.
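    The role of the state object is easiest to see when a suspend function really does suspend. The following is a minimal sketch using only standard-library calls (`suspendCoroutine`, `startCoroutine`, `resume`); `waitForValue`, `addTwoValues` and `runAddTwoValues` are hypothetical names. The local `a` must survive the second suspension, so it has to live in the continuation (state) object together with the label of the state to resume in. The comments describe the general scheme, not the exact generated bytecode.

```kotlin
import kotlin.coroutines.Continuation
import kotlin.coroutines.EmptyCoroutineContext
import kotlin.coroutines.resume
import kotlin.coroutines.startCoroutine
import kotlin.coroutines.suspendCoroutine

// Continuation of the most recent suspension, kept so the caller can resume it later.
var pending: Continuation<Int>? = null

// Suspends the caller and stores its continuation instead of producing a value immediately.
suspend fun waitForValue(): Int = suspendCoroutine { cont -> pending = cont }

suspend fun addTwoValues(): Int {
    val a = waitForValue() // first suspension point: resuming continues in the next state
    val b = waitForValue() // second suspension point: `a` is kept in the state object meanwhile
    return a + b
}

// Hypothetical driver: starts the coroutine, then feeds it two values by resuming it twice.
fun runAddTwoValues(x: Int, y: Int): Int {
    var result: Int? = null
    val block: suspend () -> Int = { addTwoValues() }
    block.startCoroutine(Continuation<Int>(EmptyCoroutineContext) { result = it.getOrThrow() })
    checkNotNull(pending).resume(x) // runs until the second suspension
    checkNotNull(pending).resume(y) // runs to completion and invokes the callback
    return checkNotNull(result)
}

fun main() {
    println(runAddTwoValues(1, 2)) // prints 3
}
```

    In the Fibonacci case none of this suspending ever happens, yet each call still pays for creating and consulting such a state object.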

    In interpreted mode (java -Xint), my fake function performs almost the same as the actual suspend fun, and with the JIT enabled it is less than twice as fast as the real one. By comparison, the "direct" function implementation is about 10 times as fast. That means that the code shown explains a good part of the overhead of suspendability.