Tags: python, parallel-processing, pickle

Python: _pickle.PicklingError: Can't pickle <function <lambda>>


I'm running Python 3.9.1

Note: I know there are questions with similar titles, but those are embedded in complicated code that makes the underlying problem hard to see. This is a bare-bones reproduction that I think others will find easier to digest.

EDIT: I have Pool(processes=64) in my code, but most others will probably have to change this according to how many cores their computer has. And if it takes too long, change listLen to a smaller number.

I'm trying to learn about multiprocessing in order to solve a problem at work. I have a list of arrays on which I need to do pairwise comparisons, but for simplicity I've recreated the gist of the problem with simple integers instead of arrays, and an addition function instead of a call to some complicated comparison function. With the code below, I'm running into the titular error:

import time
from multiprocessing import Pool
import itertools
import random

def add_nums(a, b):
    return a + b

if __name__ == "__main__":
    listLen = 1000
    
    # Create a list of random numbers to do pairwise additions of
    myList = [random.choice(range(1000)) for i in range(listLen)]
    # Create a list of all pairwise combinations of the indices of the list
    index_combns = [*itertools.combinations(range(len(myList)),2)]

    # Do the pairwise operation without multiprocessing
    start_time = time.time()
    sums_no_mp = [*map(lambda x: add_nums(myList[x[0]], myList[x[1]]), index_combns)]
    end_time = time.time() - start_time
    print(f"Process took {end_time} seconds with no MP")

    # Do the pairwise operations with multiprocessing
    start_time = time.time()
    pool = Pool(processes=64)
    sums_mp = pool.map(lambda x: add_nums(myList[x[0]], myList[x[1]]), index_combns)
    end_time = time.time() - start_time
    print(f"Process took {end_time} seconds with MP")

    pool.close()
    pool.join()


Solution

  • I'm not exactly sure why (though a thorough read through the multiprocessing docs would probably have an answer), but there's a pickling step involved in Python's multiprocessing: to hand work to a child process, the pool pickles both the function and its arguments. The standard pickle module serializes a function by reference to its importable, module-level name, and a lambda has no such name, hence the error. While I would have expected the lambdas to be inherited and not passed via pickling, I guess that's not what's happening.
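
    As an aside, a minimal sketch of how the original integer version could sidestep the error: drop the lambda, pass the values rather than the indices, and let pool.starmap unpack each pair into the already-named, module-level add_nums, which pickles fine by name:

    # Minimal sketch (reusing add_nums, myList and index_combns
    # from the question's code)
    value_pairs = [(myList[i], myList[j]) for i, j in index_combns]

    with Pool(processes=64) as pool:
        # starmap unpacks each (a, b) tuple into add_nums(a, b)
        sums_mp = pool.starmap(add_nums, value_pairs)

    For the actual use case (pairwise operations on arrays), shipping every pair of arrays through pickle can get expensive, which is where shared memory helps.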

    Following the discussion in the comments, consider something like this approach:

    import time
    from multiprocessing import Pool
    import itertools
    import numpy as np
    from multiprocessing import shared_memory
    
    def add_mats(a, b):
        # time.sleep(0.00001)  # uncomment to simulate a pricier comparison
        return (a + b)
    
    # Helper for mp version
    def add_mats_shared(shm_name, array_shape, array_dtype, i1, i2):
        shm = shared_memory.SharedMemory(name=shm_name)
        stacked = np.ndarray(array_shape, dtype=array_dtype, buffer=shm.buf)
        a = stacked[i1]
        b = stacked[i2]
        result = add_mats(a, b)
        shm.close()
        return result
    
    if __name__ == "__main__":
        class Timer:
            def __init__(self):
                self.start = None
                self.stop  = None
                self.delta = None
    
            def __enter__(self):
                self.start = time.time()
                return self
    
            def __exit__(self, *exc_args):
                self.stop = time.time()
                self.delta = self.stop - self.start
    
        arrays = [np.random.rand(5,5) for _ in range(50)]
        index_combns = list(itertools.combinations(range(len(arrays)),2))
    
        # Helper for non-mp version
        def add_mats_pair(ij_pair):
            i, j = ij_pair
            a = arrays[i]
            b = arrays[j]
            return add_mats(a, b)
    
        with Timer() as t:
            # Do the pairwise operation without multiprocessing
            sums_no_mp = list(map(add_mats_pair, index_combns))
    
        print(f"Process took {t.delta} seconds with no MP")
    
    
        with Timer() as t:
            # Stack arrays and copy result into shared memory
            stacked = np.stack(arrays)
            shm = shared_memory.SharedMemory(create=True, size=stacked.nbytes)
            shm_arr = np.ndarray(stacked.shape, dtype=stacked.dtype, buffer=shm.buf)
            shm_arr[:] = stacked[:]
    
            with Pool(processes=32) as pool:
                processes = [pool.apply_async(add_mats_shared, (
                    shm.name,
                    stacked.shape,
                    stacked.dtype,
                    i,
                    j,
                )) for (i,j) in index_combns]
                sums_mp = [p.get() for p in processes]
    
            shm.close()
            shm.unlink()
    
        print(f"Process took {t.delta} seconds with MP")
    
        for i in range(len(sums_no_mp)):
            assert (sums_no_mp[i] == sums_mp[i]).all()
    
        print("Results match.")
    

    It uses multiprocessing.shared_memory to share a single numpy (N+1)-dimensional array (instead of a list of N-dimensional arrays) between the host process and child processes.
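
    If the attach-by-name mechanics are unclear, here is a minimal sketch of the pattern in isolation (the names src, view, and mirror are just for illustration):

    import numpy as np
    from multiprocessing import shared_memory

    # Creator side: allocate a shared block and copy an array into it
    src = np.arange(12, dtype=np.float64).reshape(3, 4)
    shm = shared_memory.SharedMemory(create=True, size=src.nbytes)
    view = np.ndarray(src.shape, dtype=src.dtype, buffer=shm.buf)
    view[:] = src[:]

    # Attaching side (any process that knows shm.name): no copy is made,
    # the ndarray is just a view over the same shared buffer
    other = shared_memory.SharedMemory(name=shm.name)
    mirror = np.ndarray(src.shape, dtype=src.dtype, buffer=other.buf)
    assert (mirror == src).all()

    other.close()  # every attachment closes its handle...
    shm.close()
    shm.unlink()   # ...and the creator unlinks once to free the block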

    Other things that are different but don't matter:

    • Pool is used as a context manager to prevent having to explicitly close and join it.
    • Timer is a simple context manager to time blocks of code.
    • Some of the numbers have been adjusted (e.g. 50 small arrays instead of a 1000-element list, and a pool of 32 processes instead of 64).
    • pool.map replaced with calls to pool.apply_async

    pool.map would be fine too, but you'd want to build the argument list before the .map call and unpack it in the worker function, e.g.:

    with Pool(processes=32) as pool:
        args = [(
            shm.name,
            stacked.shape,
            stacked.dtype,
            i,
            j,
        ) for (i,j) in index_combns]
        sums_mp = pool.map(add_mats_shared, args)
    
    # and 
    
    # Helper for mp version, now taking a single packed tuple
    def add_mats_shared(args):
        shm_name, array_shape, array_dtype, i1, i2 = args
        shm = shared_memory.SharedMemory(name=shm_name)
        stacked = np.ndarray(array_shape, dtype=array_dtype, buffer=shm.buf)
        result = add_mats(stacked[i1], stacked[i2])
        shm.close()
        return result
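
    Alternatively, pool.starmap unpacks the argument tuples for you, so add_mats_shared could keep its original five-parameter signature:

    sums_mp = pool.starmap(add_mats_shared, args)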