Tags: python, multiprocessing, python-multiprocessing

How to correctly handle exceptions in multiprocessing


It is possible to retrieve the outputs of workers with Pool.map, but when one worker fails, an exception is raised and the outputs can no longer be retrieved. So my idea was to record the outputs in a process-synchronized queue in order to retrieve the outputs of all successful workers.

The following snippet seems to work:

from multiprocessing import Pool, Manager
from functools import partial

def f(x, queue):
    if x == 4:
        raise Exception("Error")

    queue.put_nowait(x)

if __name__ == '__main__':
    queue = Manager().Queue()
    pool = Pool(2)

    try:
        pool.map(partial(f, queue=queue), range(6))
        pool.close()
        pool.join()
    except:
        print("An error occurred")

    while not queue.empty():
        print("Output => " + str(queue.get()))

But I was wondering whether a race condition could occur during the queue polling phase. I'm not sure whether the queue process will necessarily be alive when all workers have completed. Do you think my code is correct from that point of view?


Solution

  • As far as "how to correctly handle exceptions", which is your main question:

    First, in your case you will never get to execute pool.close and pool.join, because the exception raised by pool.map skips over them. But pool.map does not return until all the submitted tasks have either returned their results or raised an exception, so you don't actually need those calls to be sure that all of your submitted tasks have completed. If it weren't for worker function f writing the results to a queue, you would never be able to get any results back using map as long as any of your tasks raised an exception. You would instead have to submit individual tasks with apply_async and get an AsyncResult instance for each one.

    So I would say that a better way of handling exceptions in your worker functions, without having to resort to a queue, is shown in the first code block below. But note that apply_async submits tasks one at a time, which can result in many shared-memory accesses. This becomes a performance issue only when the number of tasks being submitted is very large. In that case it would be better for the worker functions to handle the exceptions themselves and somehow pass back an error indication, which allows the use of map or imap, where you can specify a chunksize.

    When using a queue, be aware that writing to a managed queue has a fair bit of overhead. The second piece of code below shows how you can reduce that overhead by using a multiprocessing.Queue instance, which, unlike the managed queue, does not go through a proxy. Note the output order: it is not the order in which the tasks were submitted but rather the order in which the tasks completed -- another potential downside (or upside) of using a queue. (You can use a callback function with apply_async if you want results in completion order.) Even with your original code you should not depend on the order of results in the queue.

    from multiprocessing import Pool
    
    def f(x):
        if x == 4:
            raise Exception("Error")
    
        return x
    
    if __name__ == '__main__':
        pool = Pool(2)
        results = [pool.apply_async(f, args=(x,)) for x in range(6)]
        for x, result in enumerate(results):  # result is an AsyncResult instance
            try:
                return_value = result.get()
            except Exception:
                print(f'An error occurred for x = {x}')
            else:
                print(f'For x = {x} the return value is {return_value}')
    

    Prints:

    For x = 0 the return value is 0
    For x = 1 the return value is 1
    For x = 2 the return value is 2
    For x = 3 the return value is 3
    An error occurred for x = 4
    For x = 5 the return value is 5
    
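If you instead take the suggested approach of having the worker functions handle exceptions themselves so that map (with a chunksize) remains usable, a minimal sketch could look like this. The error-as-return-value convention here is my own, not a standard API:

```python
from multiprocessing import Pool

def f(x):
    # Catch our own exception and return it in place of a result,
    # so that pool.map can still deliver every other task's value.
    try:
        if x == 4:
            raise Exception("Error")
        return x
    except Exception as e:
        return e

if __name__ == '__main__':
    with Pool(2) as pool:
        # chunksize batches task submission, reducing per-task overhead
        for x, result in enumerate(pool.map(f, range(6), chunksize=2)):
            if isinstance(result, Exception):
                print(f'An error occurred for x = {x}')
            else:
                print(f'For x = {x} the return value is {result}')
```

Unlike the queue-based versions, the results here come back in submission order.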

    OP's Original Code Modified to Use multiprocessing.Queue

    from multiprocessing import Pool, Queue
    
    
    def init_pool(q):
        global queue
        queue = q
    
    def f(x):
        if x == 4:
            raise Exception("Error")
    
        queue.put_nowait(x)
    
    if __name__ == '__main__':
        queue = Queue()
        pool = Pool(2, initializer=init_pool, initargs=(queue,))
    
        try:
            pool.map(f, range(6))
        except Exception:
            print("An error occurred")
    
        while not queue.empty():
            print("Output => " + str(queue.get()))
    

    Prints:

    An error occurred
    Output => 0
    Output => 2
    Output => 3
    Output => 1
    Output => 5