asynchronous multiprocessing queue python-multiprocessing batch-normalization

Is mp.Queue that limited in memory?


So, I've implemented my synchronous batch normalization layer. I've tried to test it on different sets of parameters this way: each mp.Process runs a function that does the forward and backward passes and stores the output and grad values in a common mp.Queue (queue.put((out.detach().numpy(), grad.detach().numpy()))). With batch_size=32 and num_features=128 it works fine, but if I increase num_features to 256, it stops working. This is what gets printed:

    process 0 started
    All processes started
    forward done
    backward done
    sync processor must end

The function running in each process reaches its last print but never returns, so the join() calls in the main function never complete.

So, as I understand it, the mp.Queue is somehow preventing the 'parallel' function from reaching its end.

I've tried putting (1, 1) in the queue instead of the arrays, and that worked. So the problem seems to be the size of the objects I put in the queue. Is 256 * 32 too much? What is the size limit for objects put in a queue? I couldn't find it in the documentation.

Or is this not the problem at all?
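To put numbers on the size question, here is one plausible reading of the 128-vs-256 threshold. It assumes out and grad are float32 arrays of shape (batch_size, num_features) and that the code runs on Linux. mp.Queue.put() hands each item to a background feeder thread that pickles it and writes it into an OS pipe, and a process that has put items on a queue does not terminate until that thread has flushed them all. A Linux pipe buffers 64 KiB by default, so a larger write blocks until some other process reads. The sketch compares the pickled size of one put((out, grad)) with that buffer; LINUX_PIPE_CAPACITY and payload_size are illustrative names, not part of any API, and array.array('f') stands in for the NumPy arrays.

```python
import array
import pickle

# Default pipe capacity on Linux (see pipe(7)); other platforms differ.
LINUX_PIPE_CAPACITY = 64 * 1024


def payload_size(batch_size: int, num_features: int) -> int:
    """Approximate bytes one queue.put((out, grad)) pushes into the pipe.

    Assumption: out and grad are float32 arrays of shape
    (batch_size, num_features). array.array('f') is a stdlib stand-in for
    the NumPy arrays from tensor.detach().numpy(): both use 4 bytes per
    element, so only the small pickle overhead differs.
    """
    n = batch_size * num_features
    out = array.array("f", bytes(4 * n))   # n zero-valued float32 elements
    grad = array.array("f", bytes(4 * n))
    return len(pickle.dumps((out, grad)))


for num_features in (128, 256):
    size = payload_size(32, num_features)
    print(f"num_features={num_features}: {size} bytes, "
          f"larger than the pipe buffer: {size > LINUX_PIPE_CAPACITY}")
```

The 128 case should report False and the 256 case True: 32 * 256 * 4 bytes is already 32 KiB per array, so the pair alone fills the 64 KiB buffer before any pickle overhead. With several workers writing to the same queue, the unread bytes add up, so the limit can be reached even earlier.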


Solution

  • I've figured this out.

    Previously, my main function looked like this:

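    import torch
    import torch.multiprocessing as mp  # assumed; plain multiprocessing behaves the same here

    processes = []
    q = mp.Queue()
    for rank, mini_batch in enumerate(mini_batches):
        p = mp.Process(target=sync_processor, args=(mini_batch, sync_bn, rank, q, num_workers, imitate_to, port))
        p.start()
        processes.append(p)

    # Hangs with large items: each child waits to flush its queued data into
    # the pipe before exiting, but nothing reads the queue until after join().
    for p in processes:
        p.join()

    out = []
    grad = []

    # q.qsize() is only approximate (and raises NotImplementedError on macOS).
    for _ in range(q.qsize()):
        elem = q.get()
        out.append(torch.tensor(elem[0]))
        grad.append(torch.tensor(elem[1]))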

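    I assumed that my processes would fill up the queue, and that I could then let them finish and read the data from the queue in the main process. But after reading this, I now understand (well, I hope I do) that an mp.Queue needs a 'consumer': a process that has put items on a queue will not terminate until all of them have been flushed to the underlying pipe, and the pipe can only hold a limited amount of unread data. So I need to get the data from the queue before joining the processes. Now the code looks like this, and it works: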

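    import torch
    import torch.multiprocessing as mp  # assumed; plain multiprocessing behaves the same here

    processes = []
    q = mp.Queue()

    for rank, mini_batch in enumerate(mini_batches):
        p = mp.Process(target=sync_processor, args=(mini_batch, sync_bn, rank, q, num_workers, imitate_to, port))
        p.start()
        processes.append(p)

    out = []
    grad = []

    # Consume one result per worker *before* joining, so every child can flush
    # its data and exit. Assumes each worker calls put() exactly once.
    # Note: items arrive in completion order, not in rank order.
    for _ in range(num_workers):
        elem = q.get()
        out.append(torch.tensor(elem[0]))
        grad.append(torch.tensor(elem[1]))

    for p in processes:
        p.join()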
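A self-contained sketch of the same drain-then-join pattern, using only the standard library so it runs without PyTorch. worker and collect are hypothetical stand-ins for sync_processor and the main function above; each worker puts a payload well above the 64 KiB default Linux pipe buffer, the kind of item that hung when join() came first. The optional ctx argument lets you pick a start method, for example mp.get_context("fork").

```python
import multiprocessing as mp


def worker(rank, payload_size, q):
    """Hypothetical stand-in for sync_processor: put one large item on the queue."""
    # This process cannot exit until the item has been written into the pipe,
    # which only completes once someone reads from the queue.
    q.put((rank, bytes(payload_size)))


def collect(num_workers, payload_size, ctx=None):
    """Start the workers, drain the queue, and only then join them."""
    ctx = ctx if ctx is not None else mp.get_context()
    q = ctx.Queue()
    processes = [
        ctx.Process(target=worker, args=(rank, payload_size, q))
        for rank in range(num_workers)
    ]
    for p in processes:
        p.start()

    # Read exactly one item per worker instead of relying on q.qsize().
    results = [q.get() for _ in range(num_workers)]

    # Everything has been consumed, so each child's feeder thread could
    # finish writing and join() returns.
    for p in processes:
        p.join()

    # Items arrive in completion order; sort by rank for a stable result.
    return sorted(results)


if __name__ == "__main__":
    res = collect(num_workers=4, payload_size=1_000_000)
    print([(rank, len(data)) for rank, data in res])
    # prints [(0, 1000000), (1, 1000000), (2, 1000000), (3, 1000000)]
```

Moving the join() loop above the get() loop in collect should reproduce the original hang, because the workers block while writing data that nobody reads.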