I wrote a naive test bed to measure the performance of three kinds of factorial implementation: loop-based, non-tail-recursive, and tail-recursive.
To my surprise, the worst performers were the loop versions (I expected «while» to be more efficient, so I provided both «for» and «while»), which took almost twice as long as the tail-recursive alternative.
ANSWER: after fixing the loop implementations to avoid the «*=» operator, which performs worse with BigInt because of its internals, the «loops» became the fastest, as expected.
Another «voodoo» behavior I experienced was the StackOverflow exception, which wasn't thrown systematically for the same input in the non-tail-recursive implementation. I can circumvent the StackOverflow by progressively calling the function with larger and larger values… I feel crazy :) ANSWER: the JVM needs to converge during startup; after that the behavior is coherent and systematic.
This is the code:
import scala.annotation.tailrec

object Factorial {
  type Out = BigInt

  // Plain recursion: each multiplication happens after the recursive call returns.
  def calculateByRecursion(n: Int): Out = {
    require(n > 0, "n must be positive")
    n match {
      case 1 => 1
      case _ => n * calculateByRecursion(n - 1)
    }
  }

  // «for» over a Range, multiplying with the small Int on the left.
  def calculateByForLoop(n: Int): Out = {
    require(n > 0, "n must be positive")
    var accumulator: Out = 1
    for (i <- 1 to n)
      accumulator = i * accumulator
    accumulator
  }

  // Same computation with an explicit «while» loop.
  def calculateByWhileLoop(n: Int): Out = {
    require(n > 0, "n must be positive")
    var accumulator: Out = 1
    var i = 1
    while (i <= n) {
      accumulator = i * accumulator
      i += 1
    }
    accumulator
  }

  // Tail recursion counting down from n, so the accumulator gets large early.
  def calculateByTailRecursion(n: Int): Out = {
    require(n > 0, "n must be positive")
    @tailrec def fac(n: Int, acc: Out): Out = n match {
      case 1 => acc
      case _ => fac(n - 1, n * acc)
    }
    fac(n, 1)
  }

  // Tail recursion counting up from 1 to n.
  def calculateByTailRecursionUpward(n: Int): Out = {
    require(n > 0, "n must be positive")
    @tailrec def fac(i: Int, acc: Out): Out =
      if (i == n) n * acc
      else fac(i + 1, i * acc)
    fac(1, 1)
  }

  def comparePerformance(n: Int): Unit = {
    // Prints the elapsed time, and the result too when showResult is true.
    def showOutput[A](msg: String, data: (Long, A), showResult: Boolean = false): Unit =
      if (showResult) printf("%s returned %s in %d ms\n", msg, data._2.toString, data._1)
      else printf("%s in %d ms\n", msg, data._1)

    // Returns (elapsed milliseconds, result) for a single call of f.
    def measure[A](f: () => A): (Long, A) = {
      val start = System.currentTimeMillis
      val o = f()
      (System.currentTimeMillis - start, o)
    }

    showOutput("By for loop", measure(() => calculateByForLoop(n)))
    showOutput("By while loop", measure(() => calculateByWhileLoop(n)))
    showOutput("By non-tail recursion", measure(() => calculateByRecursion(n)))
    showOutput("By tail recursion", measure(() => calculateByTailRecursion(n)))
    showOutput("By tail recursion upward", measure(() => calculateByTailRecursionUpward(n)))
  }
}
What follows is some output from the sbt console (before the «while» implementation):
scala> example.Factorial.comparePerformance(10000)
By loop in 3 ms
By non-tail recursion in >>>>> StackOverflow!!!!!… see later!!!
........
scala> example.Factorial.comparePerformance(1000)
By loop in 3 ms
By non-tail recursion in 1 ms
By tail recursion in 4 ms
scala> example.Factorial.comparePerformance(5000)
By loop in 105 ms
By non-tail recursion in 27 ms
By tail recursion in 34 ms
scala> example.Factorial.comparePerformance(10000)
By loop in 236 ms
By non-tail recursion in 106 ms >>>> Now works!!!
By tail recursion in 127 ms
scala> example.Factorial.comparePerformance(20000)
By loop in 977 ms
By non-tail recursion in 495 ms
By tail recursion in 564 ms
scala> example.Factorial.comparePerformance(30000)
By loop in 2285 ms
By non-tail recursion in 1183 ms
By tail recursion in 1281 ms
What follows is some output from the sbt console (after the «while» implementation):
scala> example.Factorial.comparePerformance(10000)
By for loop in 252 ms
By while loop in 246 ms
By non-tail recursion in 130 ms
By tail recursion in 136 ms
scala> example.Factorial.comparePerformance(20000)
By for loop in 984 ms
By while loop in 1091 ms
By non-tail recursion in 508 ms
By tail recursion in 560 ms
What follows is some output from the sbt console (after the «upward» tail recursion implementation); the world comes back to sanity:
scala> example.Factorial.comparePerformance(10000)
By for loop in 259 ms
By while loop in 229 ms
By non-tail recursion in 114 ms
By tail recursion in 119 ms
By tail recursion upward in 105 ms
scala> example.Factorial.comparePerformance(20000)
By for loop in 1053 ms
By while loop in 957 ms
By non-tail recursion in 513 ms
By tail recursion in 565 ms
By tail recursion upward in 470 ms
What follows is some output from the sbt console after fixing the BigInt multiplication in the «loops»; the world is totally sane:
scala> example.Factorial.comparePerformance(20000)
By for loop in 498 ms
By while loop in 502 ms
By non-tail recursion in 521 ms
By tail recursion in 611 ms
By tail recursion upward in 503 ms
BigInt overhead and a stupid implementation on my part masked the expected behavior.
PS: In the end I should retitle this post «A lesson learned about BigInts».
For loops are not actually quite loops; they're for comprehensions on a range. If you actually want a loop, you need to use while. (Actually, I think the BigInt multiplication here is heavyweight enough so it shouldn't matter. But you'll notice if you're multiplying Ints.)
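As a minimal sketch of that point: a «for» over a Range is rewritten by the compiler into a foreach call with a closure, whereas «while» stays a plain loop. The object and method names below (LoopForms, sumFor, sumForeach, sumWhile) are made up for illustration and are not part of the question's code.

```scala
object LoopForms {
  // «for» over a Range: the compiler rewrites this to (1 to n).foreach(...).
  def sumFor(n: Int): Long = {
    var total = 0L
    for (i <- 1 to n) total += i
    total
  }

  // The same thing written out by hand as a foreach with a closure.
  def sumForeach(n: Int): Long = {
    var total = 0L
    (1 to n).foreach(i => total += i)
    total
  }

  // A real loop: no Range object and no closure call per iteration.
  def sumWhile(n: Int): Long = {
    var total = 0L
    var i = 1
    while (i <= n) {
      total += i
      i += 1
    }
    total
  }
}
```

All three return the same value; any speed difference comes only from how the iteration is driven, and it is most visible when the loop body is as cheap as an Int addition.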
Also, you have confused yourself by using BigInt. The bigger your BigInt is, the slower your multiplication. Your non-tail recursion multiplies counting up while your tail recursion counts down, which means that the latter has more big numbers to multiply.
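One rough way to see the difference without timing anything is to add up the bit lengths of every intermediate accumulator in each direction. This is only a sketch: OperandSizes and its two methods are hypothetical names, and total bit length is just a proxy for multiplication cost, not a measurement of it.

```scala
object OperandSizes {
  // Sum of the bit lengths of the accumulator after each step of 1 * 2 * ... * n.
  def bitsCountingUp(n: Int): Long = {
    var acc = BigInt(1)
    var total = 0L
    var i = 1
    while (i <= n) {
      acc = acc * i
      total += acc.bitLength
      i += 1
    }
    total
  }

  // Same sum, but multiplying n * (n - 1) * ... * 1, so the accumulator grows fast early on.
  def bitsCountingDown(n: Int): Long = {
    var acc = BigInt(1)
    var total = 0L
    var i = n
    while (i >= 1) {
      acc = acc * i
      total += acc.bitLength
      i -= 1
    }
    total
  }
}
```

For any n greater than 1, bitsCountingDown(n) is larger than bitsCountingUp(n): after k steps the downward accumulator holds the product of the k largest factors, which is never smaller than k!.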
If you fix these two issues you will find that sanity is restored: loops and tail recursion are the same speed, with both regular recursion and for slower. (Regular recursion may not be slower if the JVM optimization makes it equivalent.)
(Also, the stack overflow fix is probably because the JVM starts inlining and may either make the call tail-recursive itself, or unroll the loop far enough that you don't overflow any longer.)
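Since the JVM keeps optimizing code while it runs, timing a single cold call (as the question's measure does) mixes compilation effects into the numbers. A minimal sketch of a harness that runs the body a few times before timing it is shown below; WarmedTiming, measureWarm and the warmups parameter are assumptions for illustration, and a proper benchmark tool would be more rigorous than this.

```scala
object WarmedTiming {
  // Evaluates `body` `warmups` times without timing, then returns
  // (elapsed nanoseconds, result) for one further, timed evaluation.
  def measureWarm[A](warmups: Int)(body: => A): (Long, A) = {
    var i = 0
    while (i < warmups) {
      body // result discarded; the call only gives the JIT something to optimize
      i += 1
    }
    val start = System.nanoTime()
    val result = body
    (System.nanoTime() - start, result)
  }
}
```

It could be used as, for example, WarmedTiming.measureWarm(5)(Factorial.calculateByWhileLoop(10000)), with the warm-up count chosen by experiment.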
Finally, you're getting poor results with for and while because you're multiplying with the small number on the right rather than on the left. It turns out that Java's BigInteger, which backs BigInt, multiplies faster with the smaller number on the left.
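To check this on a particular JVM, the two operand orders can be compared directly. The following is a sketch under assumptions: OperandOrder and its methods are hypothetical names, and whether (and by how much) the order matters may depend on the JVM version, so the timings should be read as an experiment rather than as a guarantee.

```scala
object OperandOrder {
  // Factorial with the small Int on the left of each multiplication.
  def smallOnLeft(n: Int): BigInt = {
    var acc = BigInt(1)
    var i = 1
    while (i <= n) {
      acc = i * acc
      i += 1
    }
    acc
  }

  // Factorial with the small Int on the right, which is what «acc *= i» expands to.
  def smallOnRight(n: Int): BigInt = {
    var acc = BigInt(1)
    var i = 1
    while (i <= n) {
      acc = acc * i
      i += 1
    }
    acc
  }

  // Elapsed nanoseconds for one evaluation of `body`.
  def nanos[A](body: => A): Long = {
    val start = System.nanoTime()
    body
    System.nanoTime() - start
  }
}
```

Both methods return the same value, so only the timings can differ; comparing OperandOrder.nanos(OperandOrder.smallOnLeft(20000)) with OperandOrder.nanos(OperandOrder.smallOnRight(20000)) over several runs shows whether the effect is present.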