Tags: python, iterator, python-multiprocessing

Python multiprocessing imap iterates over the whole iterable


In my code I am trying to achieve the following:

  1. I get each result as soon as any of the processes finishes
  2. The next item must only be pulled from the iterable when it is actually needed (if it is converted into a list, I will run out of RAM)

To my knowledge, imap from the multiprocessing module should be perfect for this task, but this code:

import multiprocessing
import time

def txt_iterator():
    for i in range(8):
        yield i
        print('iterated', i)  # printed when the next item is requested

def func(x):
    time.sleep(5)
    return x

if __name__ == '__main__':
    pool = multiprocessing.Pool(processes=4)
    for i in pool.imap(func, txt_iterator()):
        print('P2', i)
    pool.close()

produces this output:

iterated 0
iterated 1
...
iterated 7
# 5 second pause
P2 0
P2 1
P2 2
P2 3
# 5 second pause
P2 4
P2 5
P2 6
P2 7

Meaning that it iterates through the whole iterable first and only then starts dispatching tasks to the worker processes. As far as I can tell from the docs, this behavior is documented only for .map (the up-front iteration part).

The expected output is (the exact order may vary because the processes run concurrently, but you get the idea):

iterated 0
...
iterated 3
# 5 second pause
P2 0
...
P2 3
iterated 4
...
iterated 7
# 5 second pause
P2 4
...
P2 7

I am sure I am missing something here, but in case I have completely misunderstood how this function works, I would appreciate any alternative that works as intended.


Solution

  • imap does not pace its consumption of the input iterable to match the workers: the pool's internal task-handler thread drains the iterable as fast as it can, independently of how quickly you consume the results.

    You can use a threading.BoundedSemaphore (even though only one thread consumes results) to make the input generator wait until the for loop has consumed an item:

    import multiprocessing
    import threading
    import time
    
    
    def txt_iterator(sem: threading.BoundedSemaphore):
        for i in range(30):
            sem.acquire()  # blocks until the consumer has freed a slot
            yield i
            print("iterated", i)
    
    
    def func(x):
        print("starting work on", x)
        time.sleep(1)
        return x
    
    
    if __name__ == "__main__":
        # At most 4 items can be "in flight" (pulled from the generator
        # but not yet consumed by the for loop below) at any time.
        sem = threading.BoundedSemaphore(4)
        pool = multiprocessing.Pool(processes=4)
        for i in pool.imap(func, txt_iterator(sem)):
            sem.release()  # free a slot so the generator can advance
            print("P2", i)
        pool.close()
        pool.join()