Tags: python, multithreading, multiprocessing, random-seed, reproducible-research

Reproducibility with multithreading and multiprocessing in Python (how to fix random seed)


My code does the following:

  • Starts processes to collect data
  • Starts processes to test the model
  • One thread takes care of training (reads data from the collect processes)
  • One thread takes care of testing (reads data from the test processes)
  • Every time the training thread does a step, it waits for the testing thread to also do one step
  • Before doing a step, the testing thread waits for a training step

I need reproducible results, but there is randomness in both the processes and the threads. I naively fix the seeds in each process and thread, but the results are different on every run.

Is it possible to have reproducible results? I know threads are non-deterministic, but I don't launch multiple threads from the same pool: I have two pools, each running only one thread.

Below is a simple MWE. I need the output to be the same every time I run this program.

EDIT

Using the initializer argument in all pools, I can get deterministic behavior within the threads and processes. However, the order in which the processes write the data is still random, because process scheduling is non-deterministic: sometimes one process reads from the queue (and writes) first, sometimes another does.

How can I fix it?

import logging
import traceback
import torch
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import ProcessPoolExecutor
from torch import multiprocessing as mp

shandle = logging.StreamHandler()
log = logging.getLogger('rl')
log.propagate = False
log.addHandler(shandle)
log.setLevel(logging.INFO)


def fix_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


def collect(id, queue, data):
    #log.info('Collect %i started ...', id)
    while True:
        idx = queue.get()
        if idx is None:
            break
        data[idx] = torch.rand(1)
        log.info(f'Collector {id} got idx {idx} and sampled {data[idx]}')
        queue.task_done()
    #log.info('Collect %i completed', id)


def test(id, queue, data):
    #log.info('Test %i started ...', id)
    while True:
        idx = queue.get()
        if idx is None:
            break
        data[idx] = torch.rand(1)
        log.info(f'Tester {id} got idx {idx} and sampled {data[idx]}')
        queue.task_done()
    #log.info('Test %i completed', id)


def run():
    steps = 0
    num_collect_procs = 3
    num_test_procs = 2
    max_steps = 10

    data_collect = torch.zeros(num_collect_procs).share_memory_()
    data_test = torch.zeros(num_test_procs).share_memory_()

    ctx = mp.get_context('spawn')
    manager = mp.Manager()
    collect_queue = manager.JoinableQueue()
    test_queue = manager.JoinableQueue()
    train_test_queue = manager.JoinableQueue()

    collect_pool = ProcessPoolExecutor(
        num_collect_procs,
        mp_context=ctx,
        initializer=fix_seed,
        initargs=(1,)
    )
    test_pool = ProcessPoolExecutor(
        num_test_procs,
        mp_context=ctx,
        initializer=fix_seed,
        initargs=(1,)
    )

    for i in range(num_collect_procs):
        future = collect_pool.submit(collect, i, collect_queue, data_collect)

    for i in range(num_test_procs):
        future = test_pool.submit(test, i, test_queue, data_test)

    def run_train():
        nonlocal steps
        #log.info('Training thread started ...')
        while steps < max_steps:
            train_test_queue.put(True)
            train_test_queue.join()
            for idx in range(num_collect_procs):
                collect_queue.put(idx)
            log.info('Training, %i %f', steps, data_collect.sum() + torch.rand(1))
            collect_queue.join()
            steps += 1
        #log.info('Training ended')
        for i in range(num_collect_procs):
            collect_queue.put(None)
        train_test_queue.put(None)

    def run_test():
        nonlocal steps
        #log.info('Testing thread started ...')
        while steps < max_steps:
            status = train_test_queue.get()
            if status is None:
                break
            for idx in range(num_test_procs):
                test_queue.put(idx)
            log.info('Testing, %i %f', steps, data_test.sum() + torch.rand(1))
            test_queue.join()
            train_test_queue.task_done()
        #log.info('Testing ended')
        for i in range(num_test_procs):
            test_queue.put(None)

    training_thread = ThreadPoolExecutor(1, initializer=fix_seed, initargs=(1,))
    testing_thread = ThreadPoolExecutor(1, initializer=fix_seed, initargs=(1,))
    training_thread.submit(run_train)
    testing_thread.submit(run_test)


if __name__ == '__main__':
    run()


Solution

  • I am not familiar with torch and I could not easily tell whether its random number generator is sharable across all processes or whether each process has its own generator that will generate the same sequence of numbers if they are both seeded identically.
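
    (A quick way to check, sketched below under the assumption that torch is installed, is to seed two spawned processes identically and compare what they report: if the two outputs are identical, each process has its own generator that replays the same seeded sequence, rather than all processes sharing a single generator.)

        import torch
        from torch import multiprocessing as mp

        def draw(seed, queue):
            torch.manual_seed(seed)            # seed this process's generator
            queue.put(torch.rand(3).tolist())  # report the first three numbers it draws

        if __name__ == '__main__':
            ctx = mp.get_context('spawn')
            queue = ctx.SimpleQueue()
            procs = [ctx.Process(target=draw, args=(0, queue)) for _ in range(2)]
            for p in procs:
                p.start()
            first, second = queue.get(), queue.get()
            for p in procs:
                p.join()
            # Identical values => per-process generators replaying the same
            # seeded sequence (not one shared generator).
            print(first == second, first, second)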

    Let's first assume the generator is sharable, i.e. each process is effectively making calls to the same shared random number generator seeded with 0, and that the first two random numbers in that sequence are 9 and 11. Let's also assume you have only two collect processes, p1 and p2. When you run the program the first time, this is the order of events:

    1. p1 retrieves idx value of 0 from queue
    2. p1 gets the first random number, 9, and assigns data[0] = 9
    3. p2 retrieves idx value of 1 from queue
    4. p2 gets the second random number, 11, and stores data[1] = 11

    The next time you run, this is the order in which events occur:

    1. p1 retrieves idx value of 0 from queue and then loses control of the CPU before it has a chance to get a random number and store it
    2. p2 retrieves idx value of 1 from queue
    3. p2 gets the first random number, 9, and stores data[1] = 9
    4. p1 gets the second random number, 11, and stores data[0] = 11

    Already we see the results are not duplicated. The only way to ensure duplication would be to serialize everything from idx = queue.get() through the call to torch.rand(1) with a multiprocessing.Semaphore, so that any process that retrieves the Nth index is guaranteed to also retrieve the Nth random number from the seeded sequence. Assuming that you are doing "real" work in your actual code and that the result for a given index depends only on the random number used, this should be doable without any real performance impact. You would allocate a semaphore, use the initializer and initargs arguments to initialize each pool process with it, and place the previously described critical section within a with semaphore: block:

    def init_pool_processes(sem):
        global semaphore
        semaphore = sem

    def collect(...):
        ...
        while True:
            with semaphore:
                idx = queue.get()
                if idx is None:
                    break
                random_number = torch.rand(1)
            # do the real work outside the critical section:
            data[idx] = some_function_of(random_number)
        ...

    # and in run():
    collect_pool = ProcessPoolExecutor(num_collect_procs,
                                       mp_context=ctx,
                                       initializer=init_pool_processes,
                                       initargs=(multiprocessing.Semaphore(),)
                                       )
    

    Now let's repeat the same two scenarios, this time assuming each process has its own random number generator:

    First run:

    1. p1 retrieves idx value of 0 from queue
    2. p1 gets the first random number, 9, and assigns data[0] = 9
    3. p2 retrieves idx value of 1 from queue
    4. p2 gets the first random number, 9, and stores data[1] = 9

    Second run:

    1. p1 retrieves idx value of 0 from queue and then loses control of the CPU before it has a chance to get a random number and store it
    2. p2 retrieves idx value of 1 from queue
    3. p2 gets the first random number, 9, and stores data[1] = 9
    4. p1 gets the second random number, 11, and stores data[0] = 11

    The only way I can see of ensuring duplicate runs is to split all the possible indices into two groups and use two input queues: p1 gets one queue, to which you put half the indices, and p2 gets the other, to which you put the remaining indices. That ensures that the random number used for any given index does not vary from run to run, i.e. the Nth index retrieved by a given process is always paired with the Nth random number from that process's generator. In this case you should seed each process differently to avoid computing the same results for different indices.

    Update

    Your logic seems overly complicated, with pools (which have their own internal queues), processes, threads, queues, etc.; frankly, I am having some difficulty following what you are doing. But the following is my idea of how to achieve repeatability of results. Here I assume that the random number generator is not sharable across processes, and I use an implementation (the standard random module) for which I know that to be the case. I therefore create N processes, and each one seeds its own random number generator uniquely, to minimize the probability of multiple processes generating the same random number (which wouldn't be fatal if it were to happen, I assume). A shared array initialized to zeros is created and passed to each process. But each process has its own input queue from which it retrieves indices, and I write 1/N of the indices to each queue, so that a process always uses the ith random number of its seeded sequence for the ith index that has been passed to it:

    import random
    from multiprocessing import Queue, Process, Array
    
    N_PROCESSES = 4
    INDICES_PER_PROCESS = 5
    
    def worker(seed, queue, array):
        # Use a different seed for each process:
        random.seed(seed)
    
        while True:
            idx = queue.get()
            if idx is None:
                break
            array[idx] = random.randint(0, 100_000_000)
    
    def run_trial():
        array = Array('i', INDICES_PER_PROCESS * N_PROCESSES, lock=False)
        processes = []
    
        idx = 0
        for seed in range(N_PROCESSES):
            queue = Queue()
            p = Process(target=worker, args=(seed, queue, array))
            p.start()
            processes.append(p)
            for _ in range(INDICES_PER_PROCESS):
                queue.put(idx)
                idx += 1
            queue.put(None)
    
        for p in processes:
            p.join()
    
        return array
    
    
    
    def run_trials():
        array1 = run_trial()
        l1 = list(array1) # For easy comparison
    
        array2 = run_trial()
        l2 = list(array2)
    
        print('l1 =', l1)
        print('l2 =', l2)
        assert l1 == l2
    
    if __name__ == '__main__':
        run_trials()
    

    Prints:

    l1 = [51706749, 56448162, 5433721, 34751217, 68622131, 18034063, 76397250, 8470054, 34234785, 15826780, 7590196, 12292302, 11391326, 48460313, 22694018, 31939071, 79542916, 73045210, 17505051, 49654541]
    l2 = [51706749, 56448162, 5433721, 34751217, 68622131, 18034063, 76397250, 8470054, 34234785, 15826780, 7590196, 12292302, 11391326, 48460313, 22694018, 31939071, 79542916, 73045210, 17505051, 49654541]
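
    As a final, untested sketch of how the same idea might map back onto the torch-based MWE (the constant names and the per-step print below are purely illustrative), give each collector its own input queue and its own seed; collector i then always pairs the ith index it receives with the ith number drawn from its own generator:

        import torch
        from torch import multiprocessing as mp

        NUM_COLLECTORS = 3   # mirrors num_collect_procs in the MWE
        MAX_STEPS = 10

        def collect(seed, queue, data):
            # Each collector has its own generator with its own seed, so the
            # ith index it receives is always paired with its ith draw.
            torch.manual_seed(seed)
            while True:
                idx = queue.get()
                if idx is None:
                    break
                data[idx] = torch.rand(1)
                queue.task_done()

        def run():
            ctx = mp.get_context('spawn')
            data = torch.zeros(NUM_COLLECTORS).share_memory_()
            queues, procs = [], []
            for i in range(NUM_COLLECTORS):
                q = ctx.JoinableQueue()
                p = ctx.Process(target=collect, args=(i, q, data))
                p.start()
                queues.append(q)
                procs.append(p)
            for step in range(MAX_STEPS):
                for i, q in enumerate(queues):
                    q.put(i)          # collector i always handles index i
                for q in queues:
                    q.join()          # wait until every collector finished this step
                print(step, data.sum().item())   # identical from run to run
            for q in queues:
                q.put(None)
            for p in procs:
                p.join()

        if __name__ == '__main__':
            run()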