Tags: python, generator, python-multiprocessing

Does this parallel generator have problems in its implementation?


I'm new to multiprocessing and I'm trying to define a parallel generator to solve my problem, but I have some questions:

  • is mp.get_context('spawn') the most appropriate for the job, given that it will run on a node with Red Hat 8.4 as the OS, about 20 CPUs, and 2 GPUs?
  • can I be sure that the spawned processes will be closed and their resources freed? Are there changes that would help with this?
  • could any race conditions around the locks crash the process?
  • is there any other evident problem I have not thought of, or any way to improve the solution?

The problem

I have a dataset made of many very large objects. They can't all be loaded into memory at the same time, so I have to take them one at a time, load it (about 1 min: converting the file into a big numpy array) and then use it (about 2 min). Since the load operation runs on the CPU and the use operation runs on the GPU, I would like to run them in parallel to save some time. I would also like to spawn a load operation only when necessary, so that a couple of items are always ready for use (rather than spawning every load process at the beginning, as Pool.map would do). Finally, it has to be a generator to fit with other parts of the code, which are not part of this question.

Attempted solution

This is a generalized version of the solution I'm trying to use.

import multiprocessing as mp
import queue  # needed for queue.Empty, raised by q.get(timeout=...)
from itertools import cycle
from random import random
from time import sleep


# CPU-bound long operation with results that occupy lots of memory
def put(q, item):
    print("preparing", item)
    sleep(2 + random())
    q.put(item)


# A generator that keeps `length` items ready, calculated asynchronously
def queue_generator(inputs, length=2):
    ctx = mp.get_context('spawn')
    q = ctx.Queue()
    procs = []
    inputs = cycle(inputs)  # To not run out of inputs
    stop = None
    while stop is None:
        if len(procs) < length:
            p = ctx.Process(target=put, args=(q, next(inputs)))
            procs.append(p)
            p.start()
            continue
        try:
            stop = yield q.get(timeout=30)
        except queue.Empty:
            pass
        # Reap finished processes: Process.close() raises ValueError
        # while a process is still running, so those stay in the list
        alive = []
        for p in procs:
            try:
                p.close()
            except ValueError:
                alive.append(p)
        procs = alive
    for p in procs:
        p.join()
    q.close()
    print("Queue is closed")
    yield


if __name__ == '__main__':
    g = queue_generator(list("abcd"), length=5)
    for _ in range(10):
        print("Got", next(g))
    print("No more requests")
    g.send(True)

Solution

  • Designing your own process pool may not be a good idea; instead, you should rely on the existing implementation, which should work just fine:

    from multiprocessing import Pool, Semaphore
    from time import sleep
    from random import random
    
    
    def work(i: int):
        sleep(1 + random())
        print(f"Processed value {i}")
        return i
    
    
    def loader(sem: Semaphore, length: int = 100):
        for i in range(length):
            sem.acquire()  # wait for a free slot before reading the next item
            print(f"Reading value {i}")
            yield i
    
    
    if __name__ == "__main__":
        max_workers = 4
    
        # At most `max_workers` items can be in flight at any time
        sem = Semaphore(value=max_workers)
    
        with Pool(processes=max_workers) as p:
            for result in p.imap(work, loader(sem)):
                sem.release()  # a result was consumed, free one slot
                print(f"Got result {result}")
    

    The idea is to use a multiprocessing.Pool to manage max_workers processes which will execute work in parallel. The Pool.imap() method is used to launch the concurrent mapping and returns a lazy iterator that yields values as they are processed (there is also Pool.imap_unordered() for when you don't care about the arrival order; a drop-in swap is sketched below).
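
    For instance, reusing work, loader, and sem from the block above, the unordered variant is a drop-in swap (a fragment, not a standalone program):

    # Results arrive as soon as each worker finishes, not in submission order
    with Pool(processes=max_workers) as p:
        for result in p.imap_unordered(work, loader(sem)):
            sem.release()
            print(f"Got result {result}")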

    Additionally, a semaphore is used to limit the number of items produced by the loader generator. Without it, Pool.imap() would eagerly consume all the items from the generator during the call. For your application this would mean that the whole dataset is loaded into memory, which is what you want to avoid.

    In this example, only work is executed in parallel, but you could also load the items in parallel if you want. For that, you could introduce another process pool and a queue of `maxsize=max_workers` between them; a variation on that idea is sketched below.
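
    A minimal sketch of that idea. Note that instead of an explicit intermediate queue, this version chains the two Pool.imap() iterators and lets a single semaphore bound the total number of in-flight items; the load/use functions are hypothetical stand-ins:

    from multiprocessing import Pool, Semaphore
    from random import random
    from time import sleep


    def load(i: int):
        sleep(1 + random())  # stand-in for the CPU-bound file -> numpy step
        print(f"Loaded item {i}")
        return i


    def use(i: int):
        sleep(2 + random())  # stand-in for the GPU-bound step
        print(f"Used item {i}")
        return i


    def throttled(sem: Semaphore, length: int = 20):
        for i in range(length):
            sem.acquire()  # produce the next input only when a slot is free
            yield i


    if __name__ == "__main__":
        max_in_flight = 4
        sem = Semaphore(value=max_in_flight)

        with Pool(processes=2) as load_pool, Pool(processes=2) as use_pool:
            # Chain the two lazy iterators: items flow load -> use
            loaded = load_pool.imap(load, throttled(sem))
            for result in use_pool.imap(use, loaded):
                sem.release()  # free a slot once the item is fully consumed
                print(f"Got result {result}")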


    If you want to design your own process pool, you could start with something like this:

    from multiprocessing import Process, Queue
    from time import sleep
    from random import random
    
    
    def worker(iq: Queue, oq: Queue):
        while True:
            item = iq.get()  # blocks until an input item is available
            sleep(1 + random()/5)
            print(f"Processed value {item}")
            oq.put(item)  # blocks while the output queue is full
    
    
    def producer(q: Queue, length: int = 100):
        for i in range(length):
            print(f"Loading value {i}")
            q.put(i)  # blocks while the input queue is full
    
    
    if __name__ == "__main__":
        max_workers = 4
    
        iq = Queue(maxsize=max_workers)
        oq = Queue(maxsize=max_workers)
    
        processes = [
            Process(target=worker, args=(iq, oq), daemon=True) for _ in range(max_workers)
        ]
    
        processes.append(Process(target=producer, args=(iq,)))
    
        for process in processes:
            process.start()
    
        while True:
            item = oq.get()
            print(f"Got result {item}")
    

    This launches 4 concurrent processes executing the heavy work. A fifth process is responsible for loading the input data. Two queues of limited size are used to limit the rates of production and consumption so that memory usage remains controlled.

    I didn't include error handling or stop conditions, which would be needed for a robust implementation; one common stop condition is sketched below.
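
    As a minimal sketch, a sentinel-based stop condition could look like this (the SENTINEL marker and the trivial stand-in work are my additions, assuming None never appears as real data):

    from multiprocessing import Process, Queue

    SENTINEL = None  # end-of-stream marker; assumes None is never real data


    def worker(iq: Queue, oq: Queue):
        while True:
            item = iq.get()
            if item is SENTINEL:
                oq.put(SENTINEL)  # tell the consumer this worker is done
                break
            oq.put(item * 2)  # stand-in for the heavy work


    def producer(q: Queue, length: int, n_workers: int):
        for i in range(length):
            q.put(i)
        for _ in range(n_workers):  # one sentinel per worker so all of them exit
            q.put(SENTINEL)


    if __name__ == "__main__":
        n_workers = 4
        iq, oq = Queue(maxsize=n_workers), Queue(maxsize=n_workers)

        procs = [Process(target=worker, args=(iq, oq)) for _ in range(n_workers)]
        procs.append(Process(target=producer, args=(iq, 20, n_workers)))
        for p in procs:
            p.start()

        finished = 0
        while finished < n_workers:  # stop once every worker has signalled
            item = oq.get()
            if item is SENTINEL:
                finished += 1
            else:
                print(f"Got result {item}")

        for p in procs:
            p.join()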


    A side note: you say that data loading is slow because you need to read and convert files to numpy arrays. An optimisation could be to store the numpy arrays directly on disk, thus avoiding the conversion step. Good serialisation formats for arrays also exist (e.g. Parquet, via pyarrow).
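
    For example, a minimal sketch of the direct-storage idea (the file name is arbitrary):

    import numpy as np

    # One-time conversion: store each item in NumPy's native .npy format
    arr = np.random.rand(1000, 1000)  # stand-in for a converted dataset item
    np.save("item_0.npy", arr)

    # Later loads skip the conversion step entirely; mmap_mode="r" additionally
    # avoids reading the whole file into memory at once
    loaded = np.load("item_0.npy", mmap_mode="r")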


    Another useful resource: you can take a look at the source code of the PyTorch DataLoader class (and more specifically the _MultiProcessingDataLoaderIter class). Its functionality is exactly the one you are designing: parallel loading of data with pre-fetching.
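
    For reference, a minimal sketch of what using it could look like; BigObjectDataset and the zero-array loading stand-in are hypothetical placeholders for your own code:

    import numpy as np
    from torch.utils.data import DataLoader, Dataset


    class BigObjectDataset(Dataset):
        """Hypothetical dataset: one very large object per file path."""

        def __init__(self, paths):
            self.paths = paths

        def __len__(self):
            return len(self.paths)

        def __getitem__(self, idx):
            # Stand-in for the ~1 min file -> numpy conversion
            return np.zeros((10, 10))


    if __name__ == "__main__":
        loader = DataLoader(
            BigObjectDataset([f"file_{i}" for i in range(8)]),
            batch_size=None,    # yield single items instead of batches
            num_workers=2,      # two loader processes keep items ready
            prefetch_factor=2,  # items pre-fetched per worker process
        )
        for item in loader:
            pass  # the ~2 min GPU-bound step would happen here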