kotlin, parallel-processing, forkjoinpool

How to change my helper function so that it collects the results of the parallel processing tasks


I wrote this helper function, so that I can easily process a list in parallel and only continue code execution when all the work is done. It works nicely when you don't need to return a result.

(I know that creating a new pool on every call isn't best practice. The pool can easily be moved out, but I wanted to keep the examples simple.)

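import java.util.concurrent.ForkJoinPool
import java.util.concurrent.RecursiveAction

// Wraps a lambda in a RecursiveAction so a ForkJoinPool can run it.
fun recursiveAction(action: () -> Unit): RecursiveAction {
    return object : RecursiveAction() {
        override fun compute() {
            action()
        }
    }
}

// invoke() blocks until the whole parallel stream has been processed.
fun <T> List<T>.parallelForEach(parallelSize: Int, action: (T) -> Unit) {
    ForkJoinPool(parallelSize).invoke(recursiveAction {
        this.parallelStream().forEach { action(it) }
    })
}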

Example use:

val myList: List<SomeClass> [...]
val parallelSize: Int = 8

myList.parallelForEach(parallelSize) { listElement ->
   //Some task here
}

Is there any way to make a similar helper construct for when you want to collect the results back into a list?

I know I have to use a RecursiveTask instead of the RecursiveAction, but I couldn't manage to write a helper function like I had above to wrap it.

I'd like to use it like this:

val myList: List<SomeClass> [...]
val parallelSize: Int = 8

val result: List<SomeClass> = myList.parallelForEach(parallelSize) { listElement ->
   //Some task here
}

Alternatively, is there a simpler way to do this altogether?


Solution

  • Answered by JeffMurdock over on Reddit

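    import java.util.concurrent.ForkJoinPool
    import java.util.concurrent.ForkJoinTask
    import java.util.concurrent.RecursiveTask

    // Wraps a lambda in a RecursiveTask so a ForkJoinPool can run it and hand back its value.
    fun <T> recursiveTask(action: () -> T): RecursiveTask<T> {
        return object : RecursiveTask<T>() {
            override fun compute(): T {
                return action()
            }
        }
    }
    
    // Submits one task per element and returns the results in the order of the input list.
    fun <T, E> List<T>.parallelForEach(parallelSize: Int, action: (T) -> E): List<E> {
        val pool = ForkJoinPool(parallelSize)
        try {
            val tasks = mutableListOf<ForkJoinTask<E>>()
            for (item in this) {
                tasks.add(pool.submit(recursiveTask {
                    action(item)
                }))
            }
            // join() blocks until the corresponding task has finished.
            return tasks.map { it.join() }
        } finally {
            // Release the worker threads; the pool is not reused after this call.
            pool.shutdown()
        }
    }
    
    fun main(args: Array<String>) {
        val list = listOf(1, 2, 3)
        list.parallelForEach(3) { it + 2 }.forEach { println(it) }
    }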
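
On the "simpler way" part of the question: one possible sketch, not taken from the answer above, keeps the parallel stream from the original helper and collects its results instead of calling forEach. The name parallelMap is hypothetical. The sketch assumes, as the original parallelForEach already does, that a parallel stream started from inside a ForkJoinPool task runs on that pool's workers. That is widely relied upon but is an implementation detail rather than something the Stream documentation promises.

```kotlin
import java.util.concurrent.Callable
import java.util.concurrent.ForkJoinPool
import java.util.stream.Collectors

// Hypothetical helper: maps every element in parallel and returns the results
// in the same order as the input list.
fun <T, R> List<T>.parallelMap(parallelSize: Int, transform: (T) -> R): List<R> {
    val pool = ForkJoinPool(parallelSize)
    try {
        // Assumption: the parallel stream executes on this pool because it is
        // started from one of the pool's own tasks.
        val task = pool.submit(Callable<List<R>> {
            parallelStream()
                .map { item -> transform(item) }
                .collect(Collectors.toList())
        })
        // get() blocks until the whole stream has been processed.
        return task.get()
    } finally {
        // Release the worker threads once the work is done.
        pool.shutdown()
    }
}
```

It would be called the same way as the helper above, for example `listOf(1, 2, 3).parallelMap(3) { it + 2 }`. Because the stream over a list is ordered, Collectors.toList() keeps the results in input order even though the elements are processed in parallel.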