
How to cut a for-comprehension short (break out of it) in Scala?


I have a piece of code that I would like to write as follows:

    // expensiveFunction runs twice for each element that passes the guard
    val e2 = for (e <- elements if condition(expensiveFunction(e))) yield {
      expensiveFunction(e)
    }

The condition is true for the first few elements and then becomes false for all remaining ones.

Unfortunately, this doesn't work (even if I ignore performance) because `elements` is an infinite iterator: once the condition becomes false, the guard keeps looking for the next matching element forever.

Is there a way to use a "break" in a for-comprehension so that it stops yielding elements once a certain condition holds? Otherwise, what would be the idiomatic Scala way to compute `e2`?
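A guard in a for-comprehension cannot act as a break: `if` desugars to `withFilter`, which skips failing elements instead of stopping. The sketch below uses a small, hypothetical list `xs` in which the condition fails once and then holds again, so the difference from `takeWhile` is visible; on an infinite iterator whose condition stays false, the guarded version would keep skipping forever.

```scala
// A minimal sketch of why a guard cannot act as a break.
// `xs` is hypothetical sample data: the condition holds, fails, and then holds again.
object GuardVsTakeWhile {
  val xs: List[Int] = List(1, 2, 3, 10, 4)

  // `if` in a for-comprehension desugars to withFilter: failing elements are
  // skipped and traversal continues to the end of the collection.
  val guarded: List[Int] = for (x <- xs if x < 5) yield x

  // takeWhile stops at the first element that fails the condition.
  val cut: List[Int] = xs.takeWhile(_ < 5)
}

println(GuardVsTakeWhile.guarded) // List(1, 2, 3, 4)
println(GuardVsTakeWhile.cut)     // List(1, 2, 3)
```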


Solution

    scala> def compute(i: Int) = { println(s"f$i"); 10*i }
    compute: (i: Int)Int
    
    scala> for (x <- Stream range (0, 20)) yield compute(x)
    f0
    res0: scala.collection.immutable.Stream[Int] = Stream(0, ?)
    
    scala> res0 takeWhile (_ < 100)
    res1: scala.collection.immutable.Stream[Int] = Stream(0, ?)
    
    scala> res1.toList
    f1
    f2
    f3
    f4
    f5
    f6
    f7
    f8
    f9
    f10
    res2: List[Int] = List(0, 10, 20, 30, 40, 50, 60, 70, 80, 90)
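
Applied to the question, the same idea works directly on an `Iterator`, whose `map` and `takeWhile` are lazy as well. This is a minimal sketch: `expensiveFunction`, `condition`, and the `calls` counter are hypothetical stand-ins, and `Iterator.from(0)` plays the role of the infinite `elements`. Each element is computed once, and evaluation stops at the first result that fails the condition.

```scala
// A minimal sketch; `expensiveFunction` and `condition` are hypothetical
// stand-ins for the functions in the question.
object CutShort {
  var calls = 0 // counts how often the expensive function actually runs

  // Hypothetical expensive function: records each call, then multiplies by 10.
  def expensiveFunction(e: Int): Int = { calls += 1; 10 * e }

  // Hypothetical condition: true for the first few results, false from then on.
  def condition(v: Int): Boolean = v < 100

  // An infinite iterator, standing in for `elements` in the question.
  def elements: Iterator[Int] = Iterator.from(0)

  // Compute each value once with map, then stop at the first value that fails
  // the condition. Both steps are lazy, so work happens only when toList pulls.
  val e2: List[Int] = elements.map(expensiveFunction).takeWhile(condition).toList
}

println(CutShort.e2)    // List(0, 10, 20, 30, 40, 50, 60, 70, 80, 90)
println(CutShort.calls) // 11: ten accepted values plus the first rejected one
```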
    

Edit: another demonstration:

    scala> Stream.range(0,20).map(compute).toList
    f0
    f1
    f2
    f3
    f4
    f5
    f6
    f7
    f8
    f9
    f10
    f11
    f12
    f13
    f14
    f15
    f16
    f17
    f18
    f19
    res3: List[Int] = List(0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120, 130, 140, 150, 160, 170, 180, 190)
    
    scala> Stream.range(0,20).map(compute).takeWhile(_ < 100).toList
    f0
    f1
    f2
    f3
    f4
    f5
    f6
    f7
    f8
    f9
    f10
    res4: List[Int] = List(0, 10, 20, 30, 40, 50, 60, 70, 80, 90)
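
Since Scala 2.13, `Stream` is deprecated in favor of `LazyList`, which is lazy in its head as well as its tail. A minimal sketch of the same session under that assumption, keeping the question's for-comprehension shape: the `compute` here is a variant that records its inputs in a buffer instead of printing them, and `LazyList.from(0)` is an infinite generator, so `takeWhile` alone bounds the work.

```scala
// A minimal sketch, assuming Scala 2.13+, where LazyList replaces the deprecated Stream.
import scala.collection.mutable.ListBuffer

object LazyCut {
  val evaluated: ListBuffer[Int] = ListBuffer.empty[Int] // inputs actually computed

  // Variant of the REPL's compute that records inputs instead of printing them.
  def compute(i: Int): Int = { evaluated += i; 10 * i }

  // Desugars to LazyList.from(0).map(compute). Nothing is computed yet,
  // not even the head, because LazyList is fully lazy.
  val values: LazyList[Int] = for (x <- LazyList.from(0)) yield compute(x)

  // takeWhile bounds the infinite list; toList forces only what is needed.
  val result: List[Int] = values.takeWhile(_ < 100).toList
}

println(LazyCut.result)           // List(0, 10, 20, 30, 40, 50, 60, 70, 80, 90)
println(LazyCut.evaluated.toList) // List(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10)
```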