Tags: scala, clojure, performance, tail-recursion, tail-call-optimization

Why is Clojure much faster than Scala on a recursive add function?


A friend gave me this code snippet in Clojure:

(defn sum [coll acc]
  (if (empty? coll) acc
      (recur (rest coll) (+ (first coll) acc))))
(time (sum (range 1 9999999) 0))

and asked me how it fares against a similar Scala implementation.

The Scala code I've written looks like this:

// An infinite, lazily evaluated Stream of consecutive Ints starting at n
def from(n: Int): Stream[Int] = Stream.cons(n, from(n + 1))
val ints = from(1).take(9999998)

// Tail-recursive sum of the Stream's elements into a Long accumulator
def add(a: Stream[Int], b: Long): Long = {
    if (a.isEmpty) b else add(a.tail, b + a.head)
}

val t1 = System.currentTimeMillis()
println(add(ints, 0))
val t2 = System.currentTimeMillis()
println((t2 - t1).toFloat + " msecs")

Bottom line: the Clojure code runs in about 1.8 seconds on my machine and uses less than 5MB of heap, while the Scala code runs in about 12 seconds, and 512MB of heap isn't enough (it only finishes the computation if I set the heap to 1GB).

So I'm wondering why Clojure is so much faster and slimmer in this particular case. Do you have a Scala implementation that behaves similarly in terms of speed and memory usage?

Please refrain from religious remarks; my interest lies primarily in finding out what makes Clojure so fast in this case, and whether there's a faster implementation of the algorithm in Scala. Thanks.


Solution

  • First, scalac turns self-recursive tail calls into loops whenever it can (that is, in methods that cannot be overridden, such as local, private or final methods and methods of objects), even without the -optimise flag.
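    As an illustration (a sketch assuming Scala 2.13 or 3; sumFromTo and sumFromToLoop are made-up names, not library functions), here is a self-recursive accumulator function next to a hand-written while loop of the kind tail-call elimination effectively turns it into:

    ```scala
    import scala.annotation.tailrec

    // Direct self-recursion in tail position: scalac compiles this into a
    // loop, so the stack does not grow with the size of the range.
    // Assumes to < Int.MaxValue, otherwise i + 1 would wrap around.
    @tailrec
    def sumFromTo(i: Int, to: Int, acc: Long): Long =
      if (i > to) acc
      else sumFromTo(i + 1, to, acc + i)

    // Hand-written approximation of that rewrite: the recursive call becomes
    // "reassign the parameters and jump back to the top". Not compiler output.
    def sumFromToLoop(from: Int, to: Int): Long = {
      var i = from
      var acc = 0L
      while (i <= to) {
        acc += i
        i += 1
      }
      acc
    }
    ```

    Both return 49999995000000 for 1 to 9999999; the loop is only a source-level picture of what the compiler does internally.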

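    Second, Stream and Range are two very different things. A Range stores only its start, end and step, so iterating over it needs just a counter compared against the end. A Stream is a lazily evaluated list whose elements are computed on demand and then memoized. Since you are adding up all of ints, you compute, and therefore allocate, the whole Stream; and because the val ints still refers to its head, every computed cell stays reachable and cannot be garbage-collected, which is why the heap fills up.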
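    Stream is deprecated from Scala 2.13 on, so this sketch uses its replacement, LazyList, to contrast the two behaviours (naturalsFrom and count are illustrative names, and the count is kept small): an Iterator that caches nothing, and a LazyList whose head is held in a val, just as ints is in the question.

    ```scala
    // LazyList (Scala 2.13+ replacement for Stream): each cell is computed
    // on demand and then memoized.
    def naturalsFrom(n: Int): LazyList[Int] = n #:: naturalsFrom(n + 1)

    val count = 100000

    // Iterator.from also produces elements on demand but caches nothing:
    // no reference to an element survives once the fold has consumed it.
    val viaIterator: Long = Iterator.from(1).take(count).foldLeft(0L)(_ + _)

    // `cached` refers to the head of the LazyList, so after the fold has
    // forced all `count` cells, every one of them is still reachable and
    // cannot be garbage-collected while `cached` is in scope.
    val cached: LazyList[Int] = naturalsFrom(1).take(count)
    val viaLazyList: Long = cached.foldLeft(0L)(_ + _)
    ```

    Both sums are equal; the difference is only in how much stays in memory afterwards.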

    A closer equivalent would be:

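    import scala.annotation.tailrec
    
    // Sums a Range through its iterator, which keeps just a counter,
    // so memory use stays constant regardless of the range size.
    def add(r: Range): Long = {
      // @tailrec makes compilation fail if f cannot be turned into a loop.
      @tailrec
      def f(i: Iterator[Int], acc: Long): Long =
        if (i.hasNext) f(i, acc + i.next()) else acc
    
      f(r.iterator, 0)
    }
    
    // Evaluates f once and prints the elapsed wall-clock time in milliseconds.
    def time(f: => Unit): Unit = {
      val t1 = System.currentTimeMillis()
      f
      val t2 = System.currentTimeMillis()
      println((t2 - t1).toFloat + " msecs")
    }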
    

    Normal run:

    scala> time(println(add(1 to 9999999)))
    49999995000000
    563.0 msecs
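    The printed total can be sanity-checked with the closed form n(n+1)/2. The sketch below (triangular is an illustrative helper) also shows that 1 to 9999999 includes 9999999, whereas the question's Clojure (range 1 9999999) and take(9999998) stop at 9999998, so the programs sum slightly different ranges; Scala's until builds the same half-open range as Clojure's range.

    ```scala
    // Closed form for 1 + 2 + ... + n, computed in Long to avoid Int overflow.
    def triangular(n: Long): Long = n * (n + 1) / 2

    // `1 to 9999999` includes its end value, as in the run above.
    val inclusiveSum = triangular(9999999L)  // 49999995000000

    // Clojure's (range 1 9999999) excludes its end, like take(9999998),
    // so it sums only up to 9999998.
    val exclusiveSum = triangular(9999998L)  // 49999985000001

    // `until` gives the same half-open range as Clojure's `range`.
    val untilSum = (1 until 9999999).foldLeft(0L)(_ + _)
    ```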
    

    On Scala 2.7 you need "elements" instead of "iterator", and there is no "tailrec" annotation. That annotation does not enable the optimization; it only makes the compiler complain if a definition can't be optimized with tail recursion. So on 2.7 you'll need to remove both "@tailrec" and "import scala.annotation.tailrec" from the add function above.
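    To see what the annotation checks, here is a sketch (sumAcc and sumNaive are illustrative names) of an accumulator-style function that @tailrec accepts, next to a naive recursive sum it would reject because the addition happens after the recursive call returns:

    ```scala
    import scala.annotation.tailrec

    // Accumulator style: the recursive call is the last thing evaluated,
    // so @tailrec compiles and the method runs in constant stack space.
    // Assumes n >= 0.
    @tailrec
    def sumAcc(n: Int, acc: Long): Long =
      if (n == 0) acc else sumAcc(n - 1, acc + n)

    // Not tail-recursive: `n + ...` runs after the recursive call returns.
    // Annotating this with @tailrec would be a compile error, and a large n
    // can throw StackOverflowError. Left unannotated so the sketch compiles.
    def sumNaive(n: Int): Long =
      if (n == 0) 0L else n + sumNaive(n - 1)
    ```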

    Now some considerations on alternative implementations. The simplest:

    scala> time(println(1 to 9999999 reduceLeft (_+_)))
    -2014260032
    640.0 msecs
    

    Averaged over multiple runs here, it is slower. It's also incorrect: it accumulates in an Int, so the sum overflows and wraps around. A correct version:

    scala> time(println((1 to 9999999 foldLeft 0L)(_+_)))
    49999995000000
    797.0 msecs
    

    That's slower still, running here. I honestly wouldn't have expected it to be slower, but each iteration calls the function being passed in. Once you consider that, it's a pretty good time compared to the recursive version.
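    As for the reduceLeft result, the negative number is not garbage: it is the exact sum wrapped to 32 bits. A small sketch to check that (trueSum, intSum and truncated are illustrative names):

    ```scala
    // The exact sum of 1..9999999 does not fit in an Int (max 2147483647).
    val trueSum: Long = 49999995000000L

    // Int addition wraps around modulo 2^32, so folding in Int leaves only
    // the low 32 bits of the true sum, read as a signed value.
    val intSum: Int = (1 to 9999999).foldLeft(0)(_ + _)

    // Long.toInt keeps exactly those low 32 bits, giving the same number
    // the reduceLeft version printed.
    val truncated: Int = trueSum.toInt  // -2014260032
    ```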