Scala - unexpected type switch from Map to Iterable in for comprehension?


I'm confused by the typing going on behind the scenes in for comprehensions over maps. My understanding is that the outer collection type is usually supposed to be preserved, and we see that expected behavior in the following two cases:

scala> for {
     |   (k,v) <- Map(0->1,2->3)
     | } yield k -> v
res0: scala.collection.immutable.Map[Int,Int] = Map(0 -> 1, 2 -> 3)

scala> for {
     |   (k,v) <- Map(0->1,2->3)
     |   foo = 1
     | } yield k -> v
res1: scala.collection.immutable.Map[Int,Int] = Map(0 -> 1, 2 -> 3)

But when I add a second assignment inside the for comprehension I get something surprising:

scala> for {
     |   (k,v) <- Map(0->1,2->3)
     |   foo = 1
     |   bar = 2
     | } yield k -> v
res2: scala.collection.immutable.Iterable[(Int, Int)] = List((0,1), (2,3))

Why is this happening?


Solution

  • If you run scala -Xprint:typer -e "for { ... } yield k->v", you can see the desugared version of the code. Here is a heavily simplified version of what you get:

    val m: Map[Int,Int] = Map(0->1, 2->3)
    m.map {
      case x @ (k,v) =>
        val foo = 1
        val bar = 2
        (x, foo, bar)
    }.map {
      case ((k,v), foo, bar) => (k, v)
    }
    

    Notice that when the for comprehension is rewritten into .map calls, the first .map returns foo and bar along with the original (k, v) pair, so its element type is Tuple3[(Int,Int), Int, Int]. A collection of Tuple3 values can't be a Map, so the compiler falls back to the more general result type, Iterable. To produce the output you asked for, a collection of Tuple2 values, a second .map then discards foo and bar. By that point the receiver is statically just an Iterable, so the result is an Iterable too: the fact that the collection started out as a Map isn't carried across chained .map calls.

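    Your example with only one assignment just gets lucky: its intermediate element type is Tuple2[(Int,Int), Int], which is itself a key-value pair, so the first .map still produces a Map (keyed by the original pair) and the second .map turns it back into a Map[Int,Int].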
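To make the difference concrete, here is a minimal sketch that writes out both desugarings by hand with explicit type annotations. The object name MapDesugaring and the value names (step1, one, step2, two) are made up for illustration, and the real compiler output is more involved than this.

```scala
object MapDesugaring {
  val m: Map[Int, Int] = Map(0 -> 1, 2 -> 3)

  // One assignment: the intermediate element is a pair ((k, v), foo),
  // so the first .map can still build a Map, keyed by the original pair.
  val step1: Map[(Int, Int), Int] = m.map { case x @ (k, v) =>
    val foo = 1
    (x, foo)
  }
  // Mapping a Map to pairs again gives a Map back.
  val one: Map[Int, Int] = step1.map { case ((k, v), foo) => k -> v }

  // Two assignments: the intermediate element is a triple, which cannot be
  // a Map entry, so the first .map is typed as an Iterable.
  val step2: Iterable[((Int, Int), Int, Int)] = m.map { case x @ (k, v) =>
    val foo = 1
    val bar = 2
    (x, foo, bar)
  }
  // The receiver is now only an Iterable, so the result is one as well.
  val two: Iterable[(Int, Int)] = step2.map { case ((k, v), foo, bar) => k -> v }
}
```

Changing the annotation on two to Map[Int, Int] should fail to compile, which is the same loss of type information seen in res2.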

    On the other hand, if you use a .map directly, it works:

    Map(0->1, 2->3).map { 
      case (k,v) =>
        val foo = 1
        val bar = 2 
        k -> v
    }
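If you would rather keep the for comprehension, another option (a sketch, not something the answer above proposes) is to convert the result back explicitly with toMap. The object name KeepTheMap and the value name result are placeholders.

```scala
object KeepTheMap {
  val m: Map[Int, Int] = Map(0 -> 1, 2 -> 3)

  // The comprehension may be typed as Iterable[(Int, Int)];
  // toMap rebuilds a Map from the yielded pairs.
  val result: Map[Int, Int] = (for {
    (k, v) <- m
    foo = 1
    bar = 2
  } yield k -> v).toMap
}
```

This costs one extra pass over the elements, but the result type no longer depends on how many assignments the comprehension contains.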