
Can't pickle local object when multiprocessing with pool.map


I am trying to use Python's multiprocessing.Pool, with functools.partial to fix the arguments that stay constant across calls to pool.map (i.e. only the first argument varies).

When I run the code I get the following error, and I don't know why it happens or how to solve it:

AttributeError: Can't pickle local object 
    'get_MAX_SNR_for_eventdata_file.<locals>.get_SNR_multiprocess'

I don't know why it cannot pickle an object. This is the code (which is inside a top-level function):

def get_SNR_multiprocess(binning, event_data, energy_interval, tstart, tstop, trigger_time):
    """ This function just changes the order of arguments to be able to use partial"""
    SNR=get_max_SNR_est(event_data, energy_interval, binning, tstart, tstop, trigger_time)
    return SNR

pool=multiprocessing.Pool(processes=4)

for i in range(len(energybands)-1):
    energy_interval=[energybands[i],energybands[i+1]]
    partial_func=partial(get_SNR_multiprocess, event_data=event_data, 
                         energy_interval=energy_interval, tstart=tstart, tstop=tstop, 
                         trigger_time=trigger_time)
    SNRlist=pool.map(partial_func,timescales)
pool.close()

I found a hint that the problem might be that only functions defined at the top level of a module can be pickled, according to the "What can be pickled?" section of the pickle documentation. However, I cannot figure out exactly what the problem is in my code, or how to solve it.

The function get_max_SNR_est in the code is defined in the same script and returns a value. It depends on other functions in the same script (which in turn depend on others, and so on).
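The pickling rule mentioned above can be reproduced without multiprocessing at all. The following is a minimal sketch, not the code from the question: top_level_worker, outer, local_worker and can_pickle are hypothetical names used only for illustration. It shows that a function defined at module level can be pickled, while a function defined inside another function cannot, which is what pool.map runs into when it sends the worker to the child processes.

```python
import pickle


def top_level_worker(x):
    """Defined at module level, so pickle can refer to it by its qualified name."""
    return x * 2


def outer():
    """Return a function defined inside another function (a "local object")."""
    def local_worker(x):
        return x * 2
    return local_worker


def can_pickle(obj):
    """Return True if pickle.dumps succeeds for obj, False otherwise."""
    try:
        pickle.dumps(obj)
        return True
    except (AttributeError, pickle.PicklingError):
        # Depending on the Python version, the failure for a local function
        # is reported as AttributeError or as pickle.PicklingError.
        return False


print(can_pickle(top_level_worker))  # module-level function
print(can_pickle(outer()))           # function nested inside outer()
```

Both functions do the same work when called directly; only the place where they are defined differs, and that is what decides whether they can be sent to a worker process.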

Just FYI, the code works without multiprocessing using a for loop, like:

SNRlist=[]
for i in range(len(energybands)-1):
    energy_interval=[energybands[i],energybands[i+1]]
    for binning in timescales:
        SNR=get_max_SNR_est(event_data, energy_interval, binning, tstart, tstop, 
                            trigger_time)
        SNRlist.append(SNR)

Edit: I forgot to mention that the code shown here is already inside a function. Based on @martineau's comment, I moved get_SNR_multiprocess out of that function to module level, which solves the pickling issue (see answer).


Solution

  • Thanks to @martineau's comment I found a solution for this issue. As I mentioned in the edit to the question, the code shown there is already inside a function. I moved get_SNR_multiprocess out of that function to module level, which solves the pickling issue. The new code (including the function that contained the code shown above) looks like this:

    import multiprocessing
    from functools import partial

    def get_SNR_multiprocess(binning, event_data, energy_interval, tstart, tstop, trigger_time):
        """Reorder the arguments so that functools.partial can fix everything except binning."""
        SNR=get_max_SNR_est(event_data, energy_interval, binning, tstart, tstop, trigger_time)
        return SNR
    
    def get_MAX_SNR_for_eventdata_file(event_data, energybands, timescales, tstart, tstop, trigger_time):
        """
        Gives the maximum SNR of all timescales and energybands given for a given event_data file
        """ 
        SNRlist=[]
        # Create the pool once, rather than once per energy band
        with multiprocessing.Pool(processes=4) as pool:
            for i in range(len(energybands)-1):
                energy_interval=[energybands[i],energybands[i+1]]
                partial_func=partial(get_SNR_multiprocess, event_data=event_data, energy_interval=energy_interval, tstart=tstart, tstop=tstop, trigger_time=trigger_time)
                # Extend instead of assigning, so results from earlier energy bands are kept
                SNRlist.extend(pool.map(partial_func,timescales))
    
    

    Unfortunately, this method takes the same time as the original method with a for loop.
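One possible contributor to the missing speed-up (an assumption, nothing in this thread measures it) is the overhead of starting worker processes and of sending only a handful of tasks per pool.map call. The sketch below creates the pool once and submits every (energy interval, binning) combination in a single starmap call. fake_snr is a hypothetical stand-in for get_max_SNR_est, and build_tasks and max_snr_parallel are illustrative names, not part of the original code.

```python
import multiprocessing
from itertools import product


def fake_snr(event_data, energy_interval, binning, tstart, tstop, trigger_time):
    """Hypothetical stand-in for get_max_SNR_est; returns a dummy number."""
    low, high = energy_interval
    return (high - low) * binning


def build_tasks(event_data, energybands, timescales, tstart, tstop, trigger_time):
    """Build one argument tuple per (energy interval, binning) combination."""
    intervals = [[energybands[i], energybands[i + 1]]
                 for i in range(len(energybands) - 1)]
    return [(event_data, interval, binning, tstart, tstop, trigger_time)
            for interval, binning in product(intervals, timescales)]


def max_snr_parallel(tasks, processes=4):
    """Create the pool once and evaluate every task in a single starmap call."""
    with multiprocessing.Pool(processes=processes) as pool:
        return max(pool.starmap(fake_snr, tasks))


if __name__ == "__main__":
    # Dummy inputs, chosen only to make the sketch runnable.
    tasks = build_tasks(event_data=None, energybands=[10, 20, 50],
                        timescales=[1, 2, 4], tstart=0.0, tstop=1.0,
                        trigger_time=0.0)
    print(max_snr_parallel(tasks))
```

Whether this helps depends on how long each get_max_SNR_est call takes. The arguments, including event_data, are still pickled and sent to the workers, so with a large event_data and short calls the transfer cost may outweigh the gain from running in parallel.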