
Scala foldLeft while some conditions are true


How can I emulate the following behavior in Scala, i.e. keep folding only while a condition on the accumulator holds?

def foldLeftWhile[B](z: B, p: B => Boolean)(op: (B, A) => B): B

For example

scala> val seq = Seq(1, 2, 3, 4)
seq: Seq[Int] = List(1, 2, 3, 4)
scala> seq.foldLeftWhile(0, _ < 3) { (acc, e) => acc + e }
res0: Int = 1
scala> seq.foldLeftWhile(0, _ < 7) { (acc, e) => acc + e }
res1: Int = 6

UPDATES:

Based on @Dima's answer, I realized that my original intention was a little side-effectful. So I made it consistent with takeWhile, i.e. the fold does not advance once the predicate fails. I also added some more examples to make this clearer. (Note: this will not work with Iterators.)


Solution

  • First, note that your example seems wrong. If I understand correctly what you describe, the result should be 1 (the last value for which the predicate _ < 3 was satisfied), not 6.

    The simplest way to do this is with a return statement, which is strongly frowned upon in Scala, but I thought I'd mention it for the sake of completeness.

    def foldLeftWhile[A, B](seq: Seq[A], z: B, p: B => Boolean)(op: (B, A) => B): B = seq.foldLeft(z) { case (b, a) =>
       val result = op(b, a)
       // stop early: keep the last accumulator that satisfied p
       if (!p(result)) return b
       result
    }
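
To check this against the examples in the question, here is a small self-contained sketch. The wrapper object `ReturnBasedDemo` and the value names `small` and `large` are my own, purely for illustration. I pass the type parameters explicitly, on the assumption that the compiler may not infer the predicate's parameter type from `z` in the same parameter list.

```scala
object ReturnBasedDemo {
  // Same shape as the return-based version above. The `return` inside the
  // lambda exits foldLeftWhile itself, not just the lambda.
  def foldLeftWhile[A, B](seq: Seq[A], z: B, p: B => Boolean)(op: (B, A) => B): B =
    seq.foldLeft(z) { case (b, a) =>
      val result = op(b, a)
      if (!p(result)) return b // the new value fails p, so keep the previous one
      result
    }

  val seq = Seq(1, 2, 3, 4)

  // Running sums are 1, 3, 6, 10.
  val small = foldLeftWhile[Int, Int](seq, 0, _ < 3)(_ + _) // stops before 3
  val large = foldLeftWhile[Int, Int](seq, 0, _ < 7)(_ + _) // stops before 10
}
```

Here `small` is 1 and `large` is 6, which matches the updated examples in the question.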
    

    Since we want to avoid using return, scanLeft might be a possibility:

    seq.toStream.scanLeft(z)(op).takeWhile(p).last
    

    This is a little wasteful, because the Stream keeps all the (matching) intermediate results in memory. You could use iterator instead of toStream to avoid that, but Iterator does not have .last for some reason, so you'd have to run through it one more time explicitly:

     seq.iterator.scanLeft(z)(op).takeWhile(p).foldLeft(z) { case (_, b) => b }
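
If you want the exact call syntax from the question (`seq.foldLeftWhile(...)`), the same idea can be wrapped in an implicit class. This is only a minimal sketch: the names `FoldLeftWhileSyntax` and `FoldLeftWhileOps` are hypothetical, and I swapped scanLeft for an explicit tail-recursive loop so that it stops at the first failing value without needing return.

```scala
import scala.annotation.tailrec

object FoldLeftWhileSyntax {
  // Hypothetical wrapper that adds foldLeftWhile to any Seq.
  implicit class FoldLeftWhileOps[A](seq: Seq[A]) {
    def foldLeftWhile[B](z: B, p: B => Boolean)(op: (B, A) => B): B = {
      @tailrec
      def loop(acc: B, rest: Iterator[A]): B =
        if (!rest.hasNext) acc
        else {
          val next = op(acc, rest.next())
          // Keep going only while the new accumulator satisfies p;
          // otherwise return the last value that did.
          if (p(next)) loop(next, rest) else acc
        }

      // Mirror the iterator one-liner: if z itself fails p, return z.
      if (p(z)) loop(z, seq.iterator) else z
    }
  }
}
```

After `import FoldLeftWhileSyntax._`, a call such as `Seq(1, 2, 3, 4).foldLeftWhile[Int](0, _ < 3)(_ + _)` gives 1, and the same call with `_ < 7` gives 6. I spell out the `[Int]` type argument in case the predicate's parameter type is not inferred from `z`.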