I have written a function that does some calculations and that, every time it is called, returns a different result, since it uses a different seed for the random number generator. In general, I want to run this function many times in order to obtain many samples. I have managed to use multiprocessing to run this function in parallel, using, for example, 4 processes, until the desired number of runs n_runs is reached. Here is a minimal working example (note that the function flip_coin() is just an example function that uses an rng; in reality I am using a more complex function):
import multiprocessing as mp
import random, sys

def flip_coin(n):
    # initialise the random number generator
    seed = random.randrange(sys.maxsize)
    rng = random.Random(seed)
    # do stuff and obtain result
    if rng.random() > 0.5: res = 1
    else: res = 0
    return res, seed
# total number of runs
n_runs = 100
# initialise parallel pool
pool = mp.Pool(processes=4)
# initialise empty lists for results
results, seeds = [], []
for result in pool.map(flip_coin, range(n_runs)):
    # save result and the seed that generated that result
    results.append(result[0])
    seeds.append(result[1])
# close parallel pool
pool.close(); pool.join()
Now, instead of fixing n_runs a priori, I would like to impose a different condition, one that is met only after an unknown number of calls to the function. For example, I would like to fix the number of 1's returned by the function. Without using multiprocessing, I would do something like this:
# desired number of ones
n_ones = 10
# counter to keep track of the ones
counter = 0
# empty list for seeds
seeds = []
while counter < n_ones:
    result = flip_coin(1)
    # if we got a 1, increase counter and save seed
    if result[0] == 1:
        counter += 1
        seeds.append(result[1])
The question is: how do I parallelise such a while loop?
See my comment concerning whether multiprocessing is even advantageous here. But this is how I would do it: the idea is that once enough 1 values have been returned, you (implicitly) call terminate on the pool:
import multiprocessing as mp
import random, sys

def flip_coin(n):
    # initialise the random number generator
    seed = random.randrange(sys.maxsize)
    rng = random.Random(seed)
    # do stuff and obtain result
    if rng.random() > 0.5: res = 1
    else: res = 0
    return res, seed

def main():
    def generate_run_numbers():
        # An infinite supply of run numbers:
        n = 0
        while True:
            n += 1
            yield n

    n_ones = 10
    counter = 0
    results, seeds = [], []

    with mp.Pool(processes=4) as pool:
        for res, seed in pool.imap_unordered(flip_coin, generate_run_numbers()):
            results.append(res)
            seeds.append(seed)
            if res == 1:
                counter += 1
                if counter == n_ones:
                    break
    # Leaving the with block results in a call to pool.terminate(), which
    # kills all pool processes so that no more tasks will be processed.
    print(len(results), 'invocations were required.')

if __name__ == '__main__':
    main()
Prints:
16 invocations were required.
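As an aside, the hand-rolled generate_run_numbers generator just counts upward forever, so itertools.count from the standard library is a drop-in replacement. A minimal sketch, assuming flip_coin and the rest of main() from the solution above are unchanged:

import itertools

with mp.Pool(processes=4) as pool:
    # itertools.count(1) yields 1, 2, 3, ... forever, so the
    # generate_run_numbers helper is no longer needed:
    for res, seed in pool.imap_unordered(flip_coin, itertools.count(1)):
        ...  # same body as above: collect results, break after n_ones ones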
Another way is to use method apply_async with a callback. Here we use a semaphore to throttle task submission so that, for a given POOL_SIZE, there are never more than POOL_SIZE tasks sitting in the task queue waiting to be executed. In this way we do not submit many more tasks than necessary, and when a pool process finishes executing a task it will not sit idle, since there will always be another task in the queue waiting to run.
import multiprocessing as mp
import random, sys
from threading import BoundedSemaphore

def flip_coin():
    # initialise the random number generator
    seed = random.randrange(sys.maxsize)
    rng = random.Random(seed)
    # do stuff and obtain result
    if rng.random() > 0.5: res = 1
    else: res = 0
    return res, seed

def main():
    def callback(return_value):
        nonlocal counter, stop

        semaphore.release()
        if not stop:
            res, seed = return_value
            results.append(res)
            seeds.append(seed)
            if res == 1:
                counter += 1
                if counter == n_ones:
                    stop = True

    POOL_SIZE = 4
    # Always have up to POOL_SIZE tasks sitting in the task queue (one per
    # pool process) so that when a pool process completes a submitted task
    # there is always a queued-up task it can immediately run:
    stop = False
    semaphore = BoundedSemaphore(2 * POOL_SIZE)
    n_ones = 10
    counter = 0
    results, seeds = [], []

    pool = mp.Pool(processes=POOL_SIZE)
    while not stop:
        semaphore.acquire()
        pool.apply_async(flip_coin, callback=callback)
    pool.close()
    pool.join()
    print(f'{len(results)} invocations were required.')

if __name__ == '__main__':
    main()
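Finally, note that in both solutions each task returns the seed together with its result, so any individual run can be reproduced later in a single process by re-seeding the generator. A minimal sketch (replay is a hypothetical helper that reuses the flip_coin logic from above):

import random

def replay(seed):
    # re-seeding an RNG with a stored seed reproduces the original draw
    rng = random.Random(seed)
    return 1 if rng.random() > 0.5 else 0

# every saved (result, seed) pair can be verified after the run:
assert all(replay(s) == r for r, s in zip(results, seeds))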