Tags: kotlin, sequence, coroutine, kotlinx.coroutines

Is this implementation of takeWhileInclusive safe?


I came across the following implementation of an inclusive takeWhile (found here):

fun <T> Sequence<T>.takeWhileInclusive(pred: (T) -> Boolean): Sequence<T> {
    var shouldContinue = true
    return takeWhile {
        val result = shouldContinue
        shouldContinue = pred(it)
        result
    }
}
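A quick single-threaded check of this implementation may help pin down its semantics. The sketch below copies the function verbatim and iterates the returned sequence twice; `iterateTwice` is a hypothetical helper. It suggests a hazard that exists even without threads: the `shouldContinue` flag is captured once per call to `takeWhileInclusive`, not once per iteration, so a second pass starts from whatever state the first pass left behind.

```kotlin
// The implementation from the question, unchanged.
fun <T> Sequence<T>.takeWhileInclusive(pred: (T) -> Boolean): Sequence<T> {
    var shouldContinue = true
    return takeWhile {
        val result = shouldContinue
        shouldContinue = pred(it)
        result
    }
}

// Hypothetical helper: iterates the same returned sequence twice, on one thread.
fun iterateTwice(): Pair<List<Int>, List<Int>> {
    val taken = sequenceOf(1, 2, 3, 4, 5).takeWhileInclusive { it < 3 }
    val first = taken.toList()  // includes 3, the first element failing the predicate
    val second = taken.toList() // starts with the flag left at false by the first pass
    return first to second
}

fun main() {
    val (first, second) = iterateTwice()
    println(first)  // [1, 2, 3]
    println(second) // []
}
```

So the returned sequence is effectively single-use, which is worth keeping in mind before even considering concurrent consumers.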

The problem is I'm not 100% convinced this is safe if used on a parallel sequence.

My concern is that we'd be relying on the shouldContinue variable to know when to stop, but we're not synchronizing access to it.

Any insights?


Solution

  • Here's what I've figured out so far.

    Question clarification

    The question was unclear: there's no such thing as a parallel sequence. I probably mixed them up with Java's parallel streams. What I meant was a sequence that is consumed concurrently.

    Sequences are synchronous

    As @LouisWasserman pointed out in the comments, sequences are not designed for parallel execution. In particular, SequenceBuilder is annotated with @RestrictsSuspension. Citing from the Kotlin coroutines repo:

    It means that no SequenceBuilder extension of lambda in its scope can invoke suspendContinuation or other general suspending function
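    To see what "synchronous" means in practice, here is a minimal sketch that records which thread runs the body of a sequence builder. It assumes Kotlin 1.3 or later, where the builder is `sequence {}` and its receiver is `SequenceScope` (the successor of the experimental `SequenceBuilder`), which also carries `@RestrictsSuspension`. `builderThreads` is a hypothetical name.

    ```kotlin
    // Hypothetical helper: records which thread executes the builder body for each element.
    fun builderThreads(): Pair<List<Int>, Boolean> {
        val caller = Thread.currentThread()
        val seen = mutableListOf<Thread>()
        val numbers = sequence {
            for (i in 1..3) {
                seen += Thread.currentThread()
                yield(i) // the only kind of suspension allowed in this scope
                // delay(10) would not compile here: SequenceScope is @RestrictsSuspension
            }
        }
        val values = numbers.toList() // the body runs lazily, driven by this call
        return values to seen.all { it === caller }
    }

    fun main() {
        println(builderThreads()) // ([1, 2, 3], true)
    }
    ```

    The body only advances when the consumer asks for the next element, and it does so on the consumer's own thread, so a sequence never introduces concurrency by itself.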

    Having said that, as @MarkoTopolnik commented, they can still be used in a parallel program just like any other object.
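    If the goal is simply to avoid state shared between iterations, one option is to keep the flag inside the iteration itself. The sketch below is an alternative, not the question's code: it uses the standard `sequence {}` builder (Kotlin 1.3+), so every call to `iterator()` re-runs the loop with fresh local state. It also stops pulling from upstream right after the first failing element, whereas the `takeWhile`-based version has to read one extra element before it can stop. It does not make one shared iterator thread-safe; that still needs external synchronization.

    ```kotlin
    // Alternative sketch: state lives in the loop, not in a captured var,
    // so each iterator() call gets its own independent pass.
    fun <T> Sequence<T>.takeWhileInclusive(pred: (T) -> Boolean): Sequence<T> = sequence {
        for (item in this@takeWhileInclusive) {
            yield(item)             // emit the element first...
            if (!pred(item)) break  // ...then stop after the first one that fails
        }
    }

    fun main() {
        val taken = sequenceOf(1, 2, 3, 4, 5).takeWhileInclusive { it < 3 }
        println(taken.toList()) // [1, 2, 3]
        println(taken.toList()) // [1, 2, 3] again: no state leaks between passes

        // Works on infinite sources too, since it stops after the failing element.
        println(generateSequence(1) { it + 1 }.takeWhileInclusive { it < 5 }.toList()) // [1, 2, 3, 4, 5]
    }
    ```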

    Sequences used in parallel

    As an example, here's a first attempt at using sequences in parallel:

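    import kotlinx.coroutines.*
    
    // Each processor takes one element from the shared, unsynchronized iterator.
    // Dispatchers.Default is multi-threaded; the original ran on CommonPool
    // (ForkJoinPool.commonPool), hence the thread names in the output.
    fun CoroutineScope.launchProcessor(id: Int, iterator: Iterator<Int>) = launch(Dispatchers.Default) {
        println("[${Thread.currentThread().name}] Processor #$id received ${iterator.next()}")
    }
    
    fun main() = runBlocking {
        val s = sequenceOf(0, 1, 2, 3, 4, 5, 6, 7, 8, 9)
        val iterator = s.iterator()
        // Ten coroutines race on the same iterator
        repeat(10) { launchProcessor(it, iterator) }
    }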
    

    One run of this code printed:

    [ForkJoinPool.commonPool-worker-2] Processor #1 received 1
    [ForkJoinPool.commonPool-worker-1] Processor #0 received 0
    [ForkJoinPool.commonPool-worker-3] Processor #2 received 2
    [ForkJoinPool.commonPool-worker-2] Processor #3 received 3
    [ForkJoinPool.commonPool-worker-1] Processor #4 received 3
    [ForkJoinPool.commonPool-worker-3] Processor #5 received 3
    [ForkJoinPool.commonPool-worker-1] Processor #7 received 5
    [ForkJoinPool.commonPool-worker-2] Processor #6 received 4
    [ForkJoinPool.commonPool-worker-1] Processor #9 received 7
    [ForkJoinPool.commonPool-worker-3] Processor #8 received 6

    This is of course not what we want: some numbers are consumed more than once (3 is received three times), while 8 and 9 are never consumed at all.
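    One way to make the shared-iterator approach correct, sketched under the assumption that the elements are non-null, is to make the "is there another element? then take it" step a single locked operation. `SynchronizedIterator` and `consumeConcurrently` are hypothetical helpers, not library APIs; the coroutine calls (`runBlocking`, `launch`, `Dispatchers.Default`, `joinAll`) are from kotlinx.coroutines 1.x.

    ```kotlin
    import java.util.Collections
    import kotlinx.coroutines.*

    // Hypothetical helper: hasNext() and next() happen under one lock, as one step.
    class SynchronizedIterator<T : Any>(private val delegate: Iterator<T>) {
        private val lock = Any()

        // Returns null once the underlying iterator is exhausted.
        fun nextOrNull(): T? = synchronized(lock) {
            if (delegate.hasNext()) delegate.next() else null
        }
    }

    // Hypothetical helper: lets `workers` coroutines drain one sequence concurrently
    // on a multi-threaded dispatcher and returns everything they received.
    fun consumeConcurrently(source: Sequence<Int>, workers: Int): List<Int> = runBlocking {
        val iterator = SynchronizedIterator(source.iterator())
        val received = Collections.synchronizedList(mutableListOf<Int>())
        val jobs = List(workers) {
            launch(Dispatchers.Default) {
                while (true) {
                    val value = iterator.nextOrNull() ?: break
                    received += value
                }
            }
        }
        jobs.joinAll()
        received.toList()
    }

    fun main() {
        val received = consumeConcurrently(sequenceOf(0, 1, 2, 3, 4, 5, 6, 7, 8, 9), workers = 10)
        println(received.sorted()) // [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
    }
    ```

    Locking only `next()` would not be enough: another worker could take the last element between a caller's `hasNext()` and `next()`.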

    Enter channels

    On the other hand, if we were to use channels, we could write something like this:

    import kotlinx.coroutines.*
    import kotlinx.coroutines.channels.*
    
    fun CoroutineScope.produceNumbers() = produce<Int> {
        var x = 1 // start from 1
        while (true) {
            send(x++) // produce next
            delay(100) // wait 0.1s
        }
    }
    
    fun CoroutineScope.launchProcessor(id: Int, channel: ReceiveChannel<Int>) = launch(Dispatchers.Default) {
        // Unlike consumeEach, a plain for loop does not cancel the shared channel
        // if this one processor fails, so it is the safer choice for fan-out.
        for (msg in channel) {
            println("[${Thread.currentThread().name}] Processor #$id received $msg")
        }
    }
    
    fun main() = runBlocking<Unit> {
        val producer = produceNumbers()
        repeat(5) { launchProcessor(it, producer) }
        delay(1000)
        producer.cancel() // cancel producer coroutine and thus kill them all
    }
    

    Then one run printed:

    [ForkJoinPool.commonPool-worker-2] Processor #0 received 1
    [ForkJoinPool.commonPool-worker-2] Processor #0 received 2
    [ForkJoinPool.commonPool-worker-1] Processor #1 received 3
    [ForkJoinPool.commonPool-worker-2] Processor #2 received 4
    [ForkJoinPool.commonPool-worker-1] Processor #3 received 5
    [ForkJoinPool.commonPool-worker-2] Processor #4 received 6
    [ForkJoinPool.commonPool-worker-2] Processor #0 received 7
    [ForkJoinPool.commonPool-worker-1] Processor #1 received 8
    [ForkJoinPool.commonPool-worker-1] Processor #2 received 9
    [ForkJoinPool.commonPool-worker-2] Processor #3 received 10

    Furthermore, we could implement takeWhileInclusive for channels like this:

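    import kotlin.coroutines.CoroutineContext
    import kotlinx.coroutines.*
    import kotlinx.coroutines.channels.*
    
    // Forwards elements while the predicate holds, plus the first element that fails it.
    // shouldContinue is local to the single producer coroutine, so no other thread touches it.
    // produce needs an explicit CoroutineScope (structured concurrency), hence `scope`.
    fun <E> ReceiveChannel<E>.takeWhileInclusive(
            scope: CoroutineScope,
            context: CoroutineContext = Dispatchers.Unconfined,
            predicate: suspend (E) -> Boolean
    ): ReceiveChannel<E> = scope.produce(context) {
        var shouldContinue = true
        consumeEach { // cancels this source channel when the loop exits
            val currentShouldContinue = shouldContinue
            shouldContinue = predicate(it)
            if (!currentShouldContinue) return@produce
            send(it)
        }
    }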
    

    And it works as expected.
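    For completeness, here is a self-contained sketch of a channel-based takeWhileInclusive together with a small usage check. It is a simplified, hypothetical variant of the function above, written against kotlinx.coroutines 1.x with an explicit `CoroutineScope` receiver: it sends each element first, stops right after the first one that fails the predicate, and then cancels the upstream channel so its producer does not stay suspended. `collectInclusive` is a hypothetical helper.

    ```kotlin
    import kotlinx.coroutines.*
    import kotlinx.coroutines.channels.*

    // Hypothetical variant: forward each element, then stop after the first failing one.
    // All state lives in this one producer coroutine, so nothing is shared between threads.
    fun <E> CoroutineScope.takeWhileInclusive(
        source: ReceiveChannel<E>,
        predicate: suspend (E) -> Boolean
    ): ReceiveChannel<E> = produce<E> {
        for (element in source) {
            send(element)
            if (!predicate(element)) break
        }
        source.cancel() // release the upstream producer once we stop reading
    }

    // Hypothetical helper: feeds 1..10 through the operator and collects the result.
    fun collectInclusive(limit: Int): List<Int> = runBlocking {
        val numbers = produce<Int> {
            for (i in 1..10) send(i)
        }
        takeWhileInclusive(numbers) { it < limit }.toList()
    }

    fun main() {
        println(collectInclusive(3)) // [1, 2, 3]
    }
    ```

    Without the `source.cancel()` call, the upstream producer in this example would remain suspended in `send(4)` and `runBlocking` would never return.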