
Scala for comprehension efficiency?


In the book "Programming in Scala", chapter 23, the author gives an example like this:

case class Book(title: String, authors: String*)
val books: List[Book] = // list of books, omitted here
// find all authors who have published at least two books

for (b1 <- books; b2 <- books if b1 != b2;
    a1 <- b1.authors; a2 <- b2.authors if a1 == a2)
yield a1

The author says this will be translated into:

books flatMap (b1 =>
   books filter (b2 => b1 != b2) flatMap (b2 =>
      b1.authors flatMap (a1 =>
        b2.authors filter (a2 => a1 == a2) map (a2 =>
           a1))))

But if you look at the map and flatMap method definitions (in TraversableLike.scala), you will find that they are themselves defined in terms of for loops:

   def map[B, That](f: A => B)(implicit bf: CanBuildFrom[Repr, B, That]): That = {
    val b = bf(repr)
    b.sizeHint(this) 
    for (x <- this) b += f(x)
    b.result
  }

  def flatMap[B, That](f: A => Traversable[B])(implicit bf: CanBuildFrom[Repr, B, That]): That = {
    val b = bf(repr)
    for (x <- this) b ++= f(x)
    b.result
  }

Well, I guess this for will in turn be translated to a foreach call, and eventually down to a while statement, which is a construct rather than an expression; Scala doesn't have a plain for construct, because it wants for to always yield something.
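
For example, as far as I understand, a for loop without a yield is just rewritten into a foreach call, so the loop inside map above becomes something like this (b and f being the builder and function from the definition quoted above):

    for (x <- this) b += f(x)
    // is rewritten by the compiler into:
    this foreach (x => b += f(x))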

So, what I want to discuss is: why does Scala do this "for translation"? The author's example uses 4 generators, which end up as a 4-level nested loop, and I think that will have really horrible performance when the books list is large.

Scala encourages people to use this kind of syntactic sugar; you constantly see code that makes heavy use of filter, map and flatMap, and it seems programmers forget that what they are really doing is nesting one loop inside another, and all that is gained is code that looks a bit shorter. What do you think?


Solution

  • For comprehensions are syntactic sugar for monadic transformation, and, as such, are useful in all sorts of places. At that, they are much more verbose in Scala than the equivalent Haskell construct (of course, Haskell is non-strict by default, so one can't talk about performance of the construct like in Scala).

    Also important, this construct keeps what is being done clear, and avoids quickly escalating indentation or unnecessary private method nesting.
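
    Just to illustrate the "monadic" part: the exact same syntax works on types that have nothing to do with looping, such as Option. A minimal sketch (parse is a made-up helper for the example):

    def parse(s: String): Option[Int] =
      if (s.nonEmpty && s.forall(_.isDigit)) Some(s.toInt) else None

    // Desugars to parse("1").flatMap(a => parse("2").map(b => a + b))
    val sum: Option[Int] = for {
      a <- parse("1")
      b <- parse("2")
    } yield a + b   // Some(3); a None anywhere short-circuits the whole thing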

    As to the final consideration, whether that hides the complexity or not, I'll posit this:

    for {
      b1 <- books
      b2 <- books
      if b1 != b2
      a1 <- b1.authors
      a2 <- b2.authors 
      if a1 == a2
    } yield a1
    

    It is very easy to see what is being done, and the complexity is clear: b^2 * a^2 (the filters don't change the complexity), where b is the number of books and a the number of authors per book. Now write the same code in Java, either with deep indentation or with private methods, and try to ascertain, at a quick glance, what the complexity of the code is.

    So, imho, this doesn't hide the complexity, but, on the contrary, makes it clear.
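
    For contrast, here is roughly what the hand-written nesting looks like, sketched in Scala rather than Java, with a mutable ListBuffer standing in for whatever accumulation you would use. The work performed is exactly the same; it is just buried in indentation:

    import scala.collection.mutable.ListBuffer

    val result = new ListBuffer[String]
    books foreach { b1 =>
      books foreach { b2 =>
        if (b1 != b2)
          b1.authors foreach { a1 =>
            b2.authors foreach { a2 =>
              if (a1 == a2) result += a1
            }
          }
      }
    }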

    As for the map/flatMap/filter definitions you mention, they do not belong to List or any other class, so they won't be applied. Basically,

    for(x <- List(1, 2, 3)) yield x * 2
    

    is translated into

    List(1, 2, 3) map (x => x * 2)
    

    and that is not the same thing as

    map(List(1, 2, 3), ((x: Int) => x * 2))
    

    which is how the definition you passed would be called. For the record, the actual implementation of map on List is:

    def map[B, That](f: A => B)(implicit bf: CanBuildFrom[Repr, B, That]): That = {
      val b = bf(repr)
      b.sizeHint(this) 
      for (x <- this) b += f(x)
      b.result
    }