Tags: python, while-loop, parallel-processing, multiprocessing, python-multiprocessing

Parallel while loop with unknown number of calls


I have written a function that does some calculations and returns a different result every time it is called, since it uses a different seed for the random number generator. In general, I want to run this function many times in order to obtain many samples.

I have managed to use multiprocessing to run this function in parallel with, for example, 4 processes, until the desired number of runs n_runs is reached. Here is a minimal working example (note that flip_coin() is just an example function that uses an rng; in reality I am using a more complex function):

import multiprocessing as mp
import random, sys

def flip_coin(n):
    # n is the run index passed by pool.map; it is not used here
    # initialise the random number generator
    seed = random.randrange(sys.maxsize)
    rng = random.Random(seed)
    # do stuff and obtain result
    if rng.random()>0.5: res = 1
    else: res = 0
    return res, seed

# total number of runs
n_runs = 100
# initialise parallel pool
pool = mp.Pool(processes = 4)
# initialise empty lists for results
results, seeds = [], []
for result in pool.map(flip_coin, range(n_runs)):
    # save result and the seed that generated that result
    results.append(result[0])
    seeds.append(result[1])
# close parallel pool
pool.close(); pool.join() 
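
Note that on platforms where multiprocessing defaults to the spawn start method (Windows, and macOS on Python 3.8+), the pool creation above has to sit under an if __name__ == '__main__': guard, roughly:

import multiprocessing as mp

# without the guard, each spawned worker re-imports this module and
# would try to create a pool of its own
if __name__ == '__main__':
    pool = mp.Pool(processes = 4)
    ...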

Now, instead of fixing n_runs a priori, I would like to use a stopping condition that is met only after an unknown number of calls to the function. For example, I would like to fix the number of 1's returned by the function. Without multiprocessing, I would do something like this:

# desired number of ones
n_ones = 10
# counter to keep track of the ones
counter = 0
# empty list for seeds
seeds = []
while counter < n_ones:
    result = flip_coin(1)
    # if we got a 1, increase counter and save seed
    if result[0] == 1: 
        counter += 1
        seeds.append(result[1])

The question is: how do I parallelise such a while loop?


Solution

  • See my comment concerning whether multiprocessing is even advantageous. But this is how I would do it. The idea is that once enough 1 values have been returned, you (implicitly) call terminate on the pool:

    import multiprocessing as mp
    import random, sys
    
    def flip_coin(n):
        # initialise the random number generator
        seed = random.randrange(sys.maxsize)
        rng = random.Random(seed)
        # do stuff and obtain result
        if rng.random()>0.5: res = 1
        else: res = 0
        return res, seed
    
    def main():
        def generate_run_numbers():
            # An infinite supply of run numbers:
            n = 0
            while True:
                n += 1
                yield n
    
        n_ones = 10
        counter = 0
        results, seeds = [], []
        with mp.Pool(processes = 4) as pool:
            for res, seed in pool.imap_unordered(flip_coin, generate_run_numbers()):
                results.append(res)
                seeds.append(seed)
                if res == 1:
                    counter += 1
                    if counter == n_ones:
                        break
        # Leaving the with block via break implicitly calls pool.terminate(),
        # which kills all pool processes so that no more tasks are processed.
        print(len(results), 'invocations were required.')
    
    if __name__ == '__main__':
        main()
    

    Prints:

    16 invocations were required.
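
    As an aside, the hand-rolled generate_run_numbers generator is equivalent to itertools.count from the standard library; a minimal demonstration:

    import itertools

    # itertools.count(1) yields 1, 2, 3, ... indefinitely, just like
    # generate_run_numbers does:
    run_numbers = itertools.count(1)
    print(next(run_numbers), next(run_numbers))  # prints: 1 2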
    

    Another way is to use the method apply_async with a callback. Here we use a semaphore to throttle task submission so that, with a given POOL_SIZE, there are never more than POOL_SIZE tasks sitting in the task queue waiting to be executed. In this way we do not submit far more tasks than necessary, and when a pool process finishes a task it does not sit idle, since there is always a task in the queue waiting to run.

    import multiprocessing as mp
    import random, sys
    from threading import BoundedSemaphore
    
    def flip_coin():
        # initialise the random number generator
        seed = random.randrange(sys.maxsize)
        rng = random.Random(seed)
        # do stuff and obtain result
        if rng.random()>0.5: res = 1
        else: res = 0
        return res, seed
    
    def main():
        def callback(return_value):
            nonlocal counter, stop
    
            # free a submission slot so the main thread can submit another task
            semaphore.release()
            if not stop:
                res, seed = return_value
                results.append(res)
                seeds.append(seed)
                if res == 1:
                    counter += 1
                    if counter == n_ones:
                        stop = True
    
        POOL_SIZE = 4
        stop = False
        # Allow up to 2 * POOL_SIZE in-flight tasks: POOL_SIZE running plus
        # POOL_SIZE sitting in the task queue, so that when a pool process
        # completes a task there is always a queued-up task it can
        # immediately run:
        semaphore = BoundedSemaphore(2 * POOL_SIZE)
        n_ones = 10
        counter = 0
        results, seeds = [], []
        pool = mp.Pool(processes = POOL_SIZE)
        while not stop:
            # block until a submission slot frees up, then submit the next task
            semaphore.acquire()
            pool.apply_async(flip_coin, callback=callback)
        pool.close()
        pool.join()
        print(f'{len(results)} invocations were required.')
    
    if __name__ == '__main__':
        main()
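
    With either approach, the returned seeds make a given run reproducible after the fact; here is a minimal sketch of that idea, reusing flip_coin's seeding logic:

    import random

    def reproduce(seed):
        # re-create the rng from the stored seed; the same seed
        # deterministically yields the same coin flip as the original run
        rng = random.Random(seed)
        return 1 if rng.random() > 0.5 else 0

    seed = 123456789  # an arbitrary example seed
    assert reproduce(seed) == reproduce(seed)  # same seed, same result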