Tags: python, multiprocessing, locking, python-multiprocessing

Run Python multiprocessing with Lock()


I have a multiprocessing script that uses pool.imap_unordered.

I want to use

multiprocessing.Lock()

Can you show the right way to use it in the following script?

import multiprocessing
import pandas as pd

def my_func(df):
    # modify df here
    # ...
    # df = df.head(1)
    return df

if __name__ == "__main__":
    df = pd.DataFrame({'a': [2, 2, 1, 1, 3, 3], 'b': [4, 5, 6, 4, 5, 6], 'c': [4, 5, 6, 4, 5, 6]})
    with multiprocessing.Pool() as pool:
        groups = (g for _, g in df.groupby("a"))
        print(df)
        print(groups)
        out = []
        for res in pool.imap_unordered(my_func, groups):
            out.append(res)
    final_df = pd.concat(out)

Solution

  • If you wanted to use a lock to serialize updates to the df being passed to your worker function, my_func, then the code below is how you would do it.

    But the question is: why would you want to do that? A lock is required only when multiple processes modify the same shared element with an operation that is not atomic. That is, a lock only makes sense if a change made by one process is visible to another process. That is clearly not the case here, because each worker receives its own pickled copy of the passed df argument, which is not a shared structure (the short sketch after the code below demonstrates this).

    import multiprocessing as mp
    import pandas as pd
    
    def init_pool(*args):
        global lock
    
        lock = args[0]
    
    def my_func(df):
        # Why are we using a lock here?
        with lock:
            # Modify df here
            ...
        return df
    
    
    if __name__ == "__main__":
        df = pd.DataFrame(
            {"a": [2, 2, 1, 1, 3, 3], "b": [4, 5, 6, 4, 5, 6], "c": [4, 5, 6, 4, 5, 6]}
        )
        lock = mp.Lock()
        with mp.Pool(initializer=init_pool, initargs=(lock,)) as pool:
            groups = (g for _, g in df.groupby("a"))
            out = pool.imap_unordered(my_func, groups)
            rdf = pd.concat(out)
            print('----------\n', rdf, sep='')
    

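    As a quick demonstration of that point (a minimal sketch, not part of the original question), you can confirm that each worker operates on its own pickled copy of the DataFrame: a mutation made inside the worker is not visible in the parent unless the modified df is returned.

    import multiprocessing as mp
    import pandas as pd

    def mutate(df):
        # This mutation only affects the worker's private copy of df.
        df["a"] = -1
        return df

    if __name__ == "__main__":
        df = pd.DataFrame({"a": [1, 2, 3]})
        with mp.Pool(1) as pool:
            result = pool.apply(mutate, (df,))
        print(df["a"].tolist())      # [1, 2, 3] -- the parent's copy is unchanged
        print(result["a"].tolist())  # [-1, -1, -1] -- only the returned copy changed
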
    Example Showing a Valid Use for a Lock

    In the following example we create a sharable integer value initialized to 0 and submit 3 tasks, where each task increments the value 1000 times; the final value should be 3000. If the increment operation v.value += 1 were not done under control of a lock, the final value might not be 3000, since incrementing is not atomic: first the current value is fetched into a temporary, the temporary is incremented, and finally the incremented value is stored back. Only the use of a lock prevents two processes from fetching the same value and incrementing it to the same final value. Note that this multiprocessing.Value instance is created with its own internal lock, which is obtained with the call v.get_lock():

    import multiprocessing as mp
    
    def init_pool(*args):
        global v
    
        v = args[0]
    
    def my_func():
        for _ in range(1000):
            with v.get_lock():
                v.value += 1
    
    if __name__ == "__main__":
        v = mp.Value('i', 0, lock=True)
        N_TASKS = 3
        pool = mp.Pool(min(N_TASKS, mp.cpu_count()), initializer=init_pool, initargs=(v,))
        for i in range(N_TASKS):
            pool.apply_async(my_func)
        # Wait for all tasks to complete:
        pool.close()
        pool.join()
        print(v.value)
    

    Prints:

    3000
    
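    To see the race for yourself, here is a variation (a sketch, not part of the original answer) that performs the increments without taking a lock; the iteration count is raised so the processes actually overlap in time. Because the read-modify-write is not atomic, the printed total typically falls short of the expected 300000:

    import multiprocessing as mp

    def init_pool(*args):
        global v

        v = args[0]

    def my_func():
        for _ in range(100_000):
            # No lock here: concurrent read-modify-write updates can be lost.
            v.value += 1

    if __name__ == "__main__":
        # lock=False yields a raw shared value with no internal lock.
        v = mp.Value('i', 0, lock=False)
        N_TASKS = 3
        pool = mp.Pool(min(N_TASKS, mp.cpu_count()), initializer=init_pool, initargs=(v,))
        for _ in range(N_TASKS):
            pool.apply_async(my_func)
        pool.close()
        pool.join()
        print(v.value)  # Usually less than 300000
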

    In the following example we create a sharable array of 3 integers. But since each task increments a different element of the array, we do not require a lock, and the array is created with lock=False:

    import multiprocessing as mp
    
    def init_pool(*args):
        global arr
    
        arr = args[0]
    
    def my_func(idx):
        for _ in range(1000):
            arr[idx] += 1
    
    if __name__ == "__main__":
        N_TASKS = 3
        arr = mp.Array('i', [0] * N_TASKS, lock=False)
        pool = mp.Pool(min(N_TASKS, mp.cpu_count()), initializer=init_pool, initargs=(arr,))
        pool.map(my_func, range(N_TASKS))
        pool.close()
        pool.join()
        # Convert to a list for easy printing:
        print(list(arr))
    

    Prints:

    [1000, 1000, 1000]
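
    By contrast, if every task incremented the same element, a lock would again be required. Here is a sketch of that variant (assuming lock=True, which wraps the array with a shared internal lock exposed via get_lock()):

    import multiprocessing as mp

    def init_pool(*args):
        global arr

        arr = args[0]

    def my_func(_):
        for _ in range(1000):
            # Every task updates element 0, so the increment must be serialized.
            with arr.get_lock():
                arr[0] += 1

    if __name__ == "__main__":
        N_TASKS = 3
        # lock=True (the default) wraps the array with one shared lock.
        arr = mp.Array('i', [0], lock=True)
        pool = mp.Pool(min(N_TASKS, mp.cpu_count()), initializer=init_pool, initargs=(arr,))
        pool.map(my_func, range(N_TASKS))
        pool.close()
        pool.join()
        print(arr[0])  # 3000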