
Context managers and multiprocessing pools


Suppose you are using a multiprocessing.Pool object and pass an initializer function through the constructor's initializer argument; that function creates a resource in the global namespace. Assume the resource is a context manager. How would you handle the life cycle of the context-managed resource, given that it has to live for the whole life of the worker process but be properly cleaned up at the end?

So far, I have something somewhat like this:

resource_cm = None
resource = None


def _worker_init(args):
    global resource_cm, resource  # both names must be declared, or resource_cm stays local
    resource_cm = open_resource(args)
    resource = resource_cm.__enter__()

From here on, the pool processes can use the resource. So far so good. But handling cleanup is trickier, since the multiprocessing.Pool class does not provide a destructor or deinitializer argument.

One of my ideas is to use the atexit module, and register the clean up in the initializer. Something like this:

def _worker_init(args):
    global resource
    resource_cm = open_resource(args)
    resource = resource_cm.__enter__()

    def _clean_up():
        # __exit__ expects (exc_type, exc_value, traceback); pass None for a clean exit
        resource_cm.__exit__(None, None, None)

    import atexit
    atexit.register(_clean_up)

Is this a good approach? Is there an easier way of doing this?

EDIT: atexit does not seem to work. At least not in the way I am using it above, so as of right now I still do not have a solution for this problem.


Solution

  • First, this is a really great question! After digging around a bit in the multiprocessing code, I think I've found a way to do this:

    When you start a multiprocessing.Pool, internally the Pool object creates a multiprocessing.Process object for each member of the pool. When those sub-processes are starting up, they call a _bootstrap function, which looks like this:

    def _bootstrap(self):
        from . import util
        global _current_process
        try:
            # ... (stuff we don't care about)
            util._finalizer_registry.clear()
            util._run_after_forkers()
            util.info('child process calling self.run()')
            try:
                self.run()
                exitcode = 0 
            finally:
                util._exit_function()
            # ... (more stuff we don't care about)
    

    The run method is what actually runs the target you gave the Process object. For a Pool process, that target is a function with a long-running while loop that waits for work items to come in over an internal queue. What's really interesting for us is what happens after self.run() returns: util._exit_function() is called.

    As it turns out, that function does some clean up that sounds a lot like what you're looking for:

    def _exit_function(info=info, debug=debug, _run_finalizers=_run_finalizers,
                       active_children=active_children,
                       current_process=current_process):
        # NB: we hold on to references to functions in the arglist due to the
        # situation described below, where this function is called after this
        # module's globals are destroyed.
    
        global _exiting
    
        info('process shutting down')
        debug('running all "atexit" finalizers with priority >= 0')  # Very interesting!
        _run_finalizers(0)
    

    Here's the docstring of _run_finalizers:

    def _run_finalizers(minpriority=None):
        '''
        Run all finalizers whose exit priority is not None and at least minpriority
    
        Finalizers with highest priority are called first; finalizers with
        the same priority will be called in reverse order of creation.
        '''
    

    The body of that function runs through the registered finalizer callbacks and executes each one:

    items = [x for x in _finalizer_registry.items() if f(x)]
    items.sort(reverse=True)
    
    for key, finalizer in items:
        sub_debug('calling %s', finalizer)
        try:
            finalizer()
        except Exception:
            import traceback
            traceback.print_exc()
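
The ordering rule from the docstring can be checked in a single process. This is a minimal sketch, assuming the undocumented `multiprocessing.util.Finalize` class (the thing that adds entries to the registry) and the private `_run_finalizers` helper behave as in the excerpts quoted here; both are internals and may change between Python versions. The labels are made up for illustration.

```python
from multiprocessing import util
from multiprocessing.util import Finalize

order = []  # records the order in which the callbacks fire

# obj=None means "not tied to any object", so an exitpriority is required.
Finalize(None, order.append, args=("priority 0",), exitpriority=0)
Finalize(None, order.append, args=("priority 10",), exitpriority=10)
Finalize(None, order.append, args=("priority 5, created first",), exitpriority=5)
Finalize(None, order.append, args=("priority 5, created second",), exitpriority=5)

# Private helper: the same call that _exit_function makes in a worker.
util._run_finalizers(0)

print(order)
```

The highest priority runs first, and the two priority-5 entries run in reverse order of creation, so the list ends up as priority 10, then the second priority-5 entry, then the first, then priority 0.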
    

    Perfect. So how do we get a callback into the _finalizer_registry? There's an undocumented class called Finalize in multiprocessing.util that is responsible for adding a callback to the registry:

    class Finalize(object):
        '''
        Class which supports object finalization using weakrefs
        '''
        def __init__(self, obj, callback, args=(), kwargs=None, exitpriority=None):
            assert exitpriority is None or type(exitpriority) is int
    
            if obj is not None:
                self._weakref = weakref.ref(obj, self)
            else:
                assert exitpriority is not None
    
            self._callback = callback
            self._args = args
            self._kwargs = kwargs or {}
            self._key = (exitpriority, _finalizer_counter.next())
            self._pid = os.getpid()
    
            _finalizer_registry[self._key] = self  # That's what we're looking for!
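
The constructor above gives a finalizer two possible triggers: the weakref on `obj` and the exit priority. A small sketch of both, assuming CPython and the current `multiprocessing.util` internals; `Handle` and the event strings are hypothetical names used only for this illustration.

```python
import gc
from multiprocessing.util import Finalize


class Handle:
    """Hypothetical stand-in for an object that owns something to release."""


events = []

# 1) Tied to an object, no exitpriority: fires when the object is collected.
#    The callback must not keep the object alive, or this never triggers.
handle = Handle()
gc_finalizer = Finalize(handle, events.append, args=("handle collected",))
del handle
gc.collect()  # CPython collects on del already; this just makes it explicit

# 2) Not tied to an object: exitpriority is mandatory, and the callback
#    runs at process exit unless it is called by hand first.
exit_finalizer = Finalize(None, events.append, args=("exit callback",),
                          exitpriority=0)
was_active = exit_finalizer.still_active()
exit_finalizer()  # run it now
exit_finalizer()  # a second call does nothing, the entry is already gone
is_active = exit_finalizer.still_active()

print(events, was_active, is_active)
```

This matters for the pool case: a callback such as a bound `__exit__` method holds a strong reference to the resource, so the weakref trigger never fires and only the exit priority path is left, which is exactly what we want for a resource that should live as long as the worker.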
    

    Ok, so putting it all together into an example:

    import multiprocessing
    from multiprocessing.util import Finalize
    
    resource_cm = None
    resource = None
    
    class Resource(object):
        def __init__(self, args):
            self.args = args
    
        def __enter__(self):
            print("in __enter__ of %s" % multiprocessing.current_process())
            return self
    
        def __exit__(self, *args, **kwargs):
            print("in __exit__ of %s" % multiprocessing.current_process())
    
    def open_resource(args):
        return Resource(args)
    
    def _worker_init(args):
        global resource_cm, resource
        print("calling init")
        resource_cm = open_resource(args)
        resource = resource_cm.__enter__()
        # Register a finalizer; __exit__ gets the usual "no exception" triple
        Finalize(resource, resource_cm.__exit__, args=(None, None, None),
                 exitpriority=16)
    
    def hi(*args):
        print("we're in the worker")
    
    if __name__ == "__main__":
        pool = multiprocessing.Pool(initializer=_worker_init, initargs=("abc",))
        pool.map(hi, range(pool._processes))
        pool.close()
        pool.join()
    

    Output:

    calling init
    in __enter__ of <Process(PoolWorker-1, started daemon)>
    calling init
    calling init
    in __enter__ of <Process(PoolWorker-2, started daemon)>
    in __enter__ of <Process(PoolWorker-3, started daemon)>
    calling init
    in __enter__ of <Process(PoolWorker-4, started daemon)>
    we're in the worker
    we're in the worker
    we're in the worker
    we're in the worker
    in __exit__ of <Process(PoolWorker-1, started daemon)>
    in __exit__ of <Process(PoolWorker-2, started daemon)>
    in __exit__ of <Process(PoolWorker-3, started daemon)>
    in __exit__ of <Process(PoolWorker-4, started daemon)>
    

    As you can see, __exit__ gets called in every worker once we close() and join() the pool.