
Unable to use pool.apply_async to aggregate results with multiprocessing


Let's say I have the following function:

def fetch_num():
    x = np.random.randint(low=0, high=1000000) # choose a number
    for i in range(5000000): # do some calculations
        j = i ** 2
    return x # return a result

This function picks a random number, then does some calculations, and returns it.

I would like to create a large list containing all of these results. The catch is that I don't want to process the same number twice, and I want to use multiprocessing to make this quicker.

I've tried the following code:

import multiprocessing as mp
from tqdm import tqdm
from parallelizing_defs import fetch_num
import os
os.system("taskset -p 0xff %d" % os.getpid())
if __name__ == '__main__':


    n = 10 # number of numbers that I want to gather

    def collect_result(result): # a callback function - only append if it is not in the results list
        if result not in results:
            results.append(result)
            pbar.update(1) # this is just for the fancy progress bar

    def error_callback(e):
        raise e

    pool = mp.Pool(6) # create 6 workers

    global results # initialize results list
    results = []
    pbar = tqdm(total=n) # initialize a progress bar
    while len(results) < n: # work until enough results have been accumulated
        pool.apply_async(fetch_num, args=(), callback=collect_result, error_callback=error_callback)
    pool.close() 
    pool.join()

My problems are:

  • The loop doesn't stop; it goes on forever.
  • The iterations are not faster; it doesn't seem to be using more than one core.

I've tried a bunch of other configurations, but nothing seems to work. This sounds like a very common situation, yet I haven't been able to find an example of this particular problem. Any ideas as to why these behaviours occur would be much appreciated.


Solution

  • You have several issues. First, you need to import numpy. But your big problem is:

    while len(results) < n: # work until enough results have been accumulated
        pool.apply_async(fetch_num, args=(), callback=collect_result, error_callback=error_callback)
    

    Because apply_async returns immediately, this loop can submit jobs far faster than results come back, so you end up submitting way too many jobs. Instead, you need to submit exactly n jobs and make fetch_num itself responsible for never returning a duplicate. The way to do that is with a shareable set holding all previously generated numbers. Unfortunately, shareable sets do not exist, but shareable dictionaries do, and they can serve the same purpose. We therefore initialize each process in the pool with a proxy to the shared dictionary and a lock to serialize access to it.
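
    To see why the loop runs away: apply_async hands back an AsyncResult right away and does not wait for the worker, so your while condition never changes between submissions. A minimal sketch of that non-blocking behaviour (slow_task here is just an illustrative stand-in for fetch_num):

    import multiprocessing as mp
    import time

    def slow_task():
        # illustrative stand-in for the real, expensive work
        time.sleep(1)
        return 42

    if __name__ == '__main__':
        with mp.Pool(2) as pool:
            async_result = pool.apply_async(slow_task)
            # apply_async returns an AsyncResult immediately; the worker has
            # not finished yet, so a while-loop that tests len(results) would
            # keep submitting more and more jobs in the meantime.
            print(async_result.ready())  # almost certainly False at this point
            print(async_result.get())    # blocks until the worker returns 42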

    It is true that functions executed by pool workers, such as fetch_num, must be importable from a separate module, but only if you are running under something like a Jupyter notebook. If you are running the program "normally" from the command line, this is not required. I have therefore included the source inline so you can see it. I have also added a print statement so that you can see that all 6 processes are running in parallel.

    import multiprocessing as mp
    import numpy as np
    from tqdm import tqdm
    
    
    def pool_init(the_dict, l):
        # Runs once in each worker process: store the proxy to the shared
        # dictionary and the lock as globals for fetch_num to use.
        global num_set, the_lock
        num_set = the_dict
        the_lock = l
    
    
    def fetch_num():
        # Hold the lock while choosing a number so that no two workers can
        # claim the same value; it is released automatically before the long
        # computation so the other workers are not blocked.
        with the_lock:
            print('fetch_num')
            while True:
                x = np.random.randint(low=0, high=1000000) # choose a number
                if x not in num_set:
                    num_set[x] = True # record it in the shared dictionary
                    break
    
        for i in range(5000000): # do some calculations
            j = i ** 2
        return x # return a result
    
    
    
    if __name__ == '__main__':
    
        with mp.Manager() as manager:
            the_dict = manager.dict()
            the_lock = mp.Lock()
            n = 10 # number of numbers that I want to gather
    
            results = []
            def collect_result(result):
                results.append(result)
                pbar.update(1) # this is just for the fancy progress bar
    
            pool = mp.Pool(6, initializer=pool_init, initargs=(the_dict, the_lock)) # create 6 workers
            pbar = tqdm(total=n) # initialize a progress bar
            for _ in range(n):
                pool.apply_async(fetch_num, args=(), callback=collect_result)
            pool.close()
            pool.join()
            print()
            print(results)
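
    As a quick sanity check, you could add the following right after pool.join() to confirm that exactly n unique values were collected:

    # Sanity check: exactly n results were gathered and none repeat.
    assert len(results) == n
    assert len(results) == len(set(results))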