Tags: python, multiprocessing, global-variables

How can global variables be accessed when using Multiprocessing and Pool?


I'm trying to avoid passing variables redundantly in dataList (e.g. [(1, globalDict), (2, globalDict), (3, globalDict)]) and would rather use them as globals instead. However, declaring global globalDict, as in the following code, doesn't make globalDict available inside the worker function.

Is there a straightforward way to access data in a multiprocessing function globally?

I read the following here:

"Communication is expensive. In contrast to communication between threads, exchanging data between processes is much more expensive. In Python, the data is pickled in to binary format before transferring on pipes. Hence, the overhead of communication can be very significant when the task is small. To reduce the extraneous cost, better assign tasks in chunk."

I'm not sure if that would apply here, but I would like to simplify data access in any case.
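Whether the quoted advice applies here can be checked with the pickle module alone. The sketch below uses a hypothetical stand-in for globalDict and compares pickling each (item, globalDict) tuple on its own with pickling all of them in a single call. Within one pickle.dumps call, an object that is referenced several times is serialized only once and later occurrences become short back-references. Pool.map sends each chunk of tasks as one pickled message, so redundant tuples cost roughly one copy of the dict per chunk rather than per item, which is the reason chunking reduces the overhead. It is still work that a global avoids.

```python
import pickle

# Hypothetical stand-in for globalDict, large enough that its size dominates.
globalDict = {f"key{i}": i for i in range(1000)}
dataList = [0, 1, 2, 3]


def bytes_dict_once():
    # Size of a single pickled copy of the dict.
    return len(pickle.dumps(globalDict))


def bytes_per_item():
    # Each (item, globalDict) tuple pickled separately: one dict copy per item.
    return sum(len(pickle.dumps((x, globalDict))) for x in dataList)


def bytes_one_chunk():
    # All tuples pickled in ONE dumps call: pickle's memo stores the shared
    # dict once and writes a short back-reference for the later tuples.
    return len(pickle.dumps([(x, globalDict) for x in dataList]))


if __name__ == "__main__":
    print(bytes_dict_once(), bytes_per_item(), bytes_one_chunk())
```

The per-item total grows with len(dataList), while the single-chunk size stays close to one copy of the dict.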

import multiprocessing as mp

def MPfunction(data):
    global globalDict

    data += 1

    # use globalDict

    return data

if __name__ == '__main__':

    pool = mp.Pool(mp.cpu_count())

    try:
        globalDict = {'data':1}

        dataList = [0, 1, 2, 3]
        data = pool.map(MPfunction, dataList, chunksize=10)

    finally:
        pool.close()
        pool.join()
        pool.terminate()

Solution

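  • On Linux, multiprocessing has traditionally used the "fork" start method (the default there before Python 3.14): each pool worker is a forked copy of the parent process with a copy-on-write view of the parent's memory. As long as you create globalDict before creating the pool, it's already there in every worker. Note that any changes a worker makes to that dict stay in that worker; neither the parent nor the other workers see them.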

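    On Windows (and on macOS since Python 3.8), the default "spawn" start method launches a fresh Python interpreter for each worker, and the needed state is pickled and unpickled in the child, so a global assigned under if __name__ == '__main__' does not exist there. Instead, pass an initializer function when you create the pool and copy the dict into a global inside it. That is one copy per worker process, which is better than sending it along with every item mapped.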

    (As an aside, start the try block only after the pool has been created, so that the finally clause doesn't reference a bad pool object if creating the pool is what raised the error.)

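    import multiprocessing as mp
    
    def MPfunction(data):
        global globalDict  # not required just to read the global
    
        data += 1
    
        # use globalDict
    
        return data
    
    def init_pool(the_dict):
        # Runs once in each worker process: store the dict as a module global
        global globalDict
        globalDict = the_dict
    
    if __name__ == '__main__':
        # Create the dict BEFORE the pool so that forked workers inherit it
        globalDict = {'data':1}
    
        if mp.get_start_method() == "fork":
            # Forked workers already have a copy of globalDict
            pool = mp.Pool(mp.cpu_count())
        else:
            # spawn/forkserver: pass init_pool itself (not init_pool(...)) plus its args
            pool = mp.Pool(mp.cpu_count(), initializer=init_pool, initargs=(globalDict,))
    
        try:
            dataList = [0, 1, 2, 3]
            data = pool.map(MPfunction, dataList, chunksize=10)
        finally:
            pool.close()
            pool.join()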
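Both mechanisms described above (inheriting a global through fork, and copying the dict in through an initializer) can be checked with a small sketch. The functions read_and_mutate, run_with_fork and run_with_initializer are hypothetical names introduced for illustration; run_with_fork assumes a POSIX system, because the "fork" start method is not available on Windows. Each worker overwrites the value it read in its own copy of globalDict, and the parent's dict comes back unchanged, which illustrates that changes stay in the child.

```python
import multiprocessing as mp

# Module-level global read by the workers (same name as in the answer).
globalDict = None


def read_and_mutate(key):
    # Hypothetical worker: read a value, then overwrite it in this
    # worker's own copy of the dict. The parent never sees the write.
    value = globalDict[key]
    globalDict[key] = -1
    return value


def init_pool(the_dict):
    # Called once in each worker process; stores the dict as a global there.
    global globalDict
    globalDict = the_dict


def run_with_fork():
    # Assign the global BEFORE creating the pool so forked workers inherit it.
    # Assumes POSIX: the "fork" start method does not exist on Windows.
    global globalDict
    globalDict = {"a": 1, "b": 2}
    ctx = mp.get_context("fork")
    with ctx.Pool(2) as pool:
        results = pool.map(read_and_mutate, ["a", "b"])
    return results, globalDict


def run_with_initializer(start_method=None):
    # Works with any start method; None selects the platform default.
    local_dict = {"a": 1, "b": 2}
    ctx = mp.get_context(start_method)
    with ctx.Pool(2, initializer=init_pool, initargs=(local_dict,)) as pool:
        results = pool.map(read_and_mutate, ["a", "b"])
    return results, local_dict
```

With the "spawn" or "forkserver" start methods, call these functions from under if __name__ == '__main__':, because those methods re-import the main module in each worker. Note also that using the pool as a context manager calls terminate() on exit, which is fine here because map() has already returned all results.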