
Python multiprocessing.Process hangs when large PyTorch tensors are initialised in both processes


Why does the code shown below either finish normally or hang depending on which lines are commented/uncommented, as described in the table below?

Summary of table: if I initialise sufficiently large tensors in both processes without using "spawn", the program hangs. I can fix it by making either tensor smaller, or by using "spawn".

Note:

  1. All memory is purely CPU; I don't even have CUDA installed on this computer
  2. This issue does not occur if I replace torch with numpy, even if I make the array size 10x larger
  3. Version information: Ubuntu 22.04.1 LTS, Python 3.10.12, torch 2.1.2+cpu
Uncommented      Commented        Behaviour
(1), (4)         (2), (3), (5)    Hang
(2), (4)         (1), (3), (5)    OK
(1), (5)         (2), (3), (4)    OK
(1), (3), (4)    (2), (5)         OK
import multiprocessing as mp
import torch

def train():
    print("start of train")
    x = torch.arange(100000)            # (1)
    x = torch.arange(10000)             # (2)
    print("end of train")

if __name__ == "__main__":
    mp.set_start_method('spawn')        # (3)
    x = torch.arange(100000)            # (4)
    x = torch.arange(10000)             # (5)
    p = mp.Process(target=train)
    p.start()
    p.join()

Solution

  • Your program hangs because you are trying to fork a multithreaded process, which is destined for trouble. As stated in the multiprocessing docs:

    Note that safely forking a multithreaded process is problematic.

    You may think that your process is not creating other threads, but it turns out that just importing pytorch spins up a bunch of background threads to support its internal machinery (thread pools for parallel CPU work, among other things). I verified this using a simple test program:

    import time
    time.sleep(10)              # window to inspect the thread count before importing torch

    import torch
    time.sleep(10)              # window to inspect the thread count after the import

    x = torch.arange(100000)
    time.sleep(10)              # window to inspect the thread count after the first op
    

    During the first sleep (before importing torch), we have only a single thread running.

    During the second sleep (after importing torch), there are 20 threads running.

    And for good measure, here's what we get during the third sleep, after calling arange: 33 threads running.
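
    If you want to see the numbers without reaching for an external tool like htop, here is one way to count native threads from inside the process. This is a minimal sketch that assumes Linux (it reads /proc/self/status); note that threading.active_count() would not help here, because it only sees Python-level threads, not pytorch's internal worker threads.

    def native_thread_count():
        # Count OS-level threads of this process by reading /proc/self/status.
        with open("/proc/self/status") as f:
            for line in f:
                if line.startswith("Threads:"):
                    return int(line.split()[1])

    print("before importing torch:", native_thread_count())

    import torch
    print("after importing torch:", native_thread_count())

    x = torch.arange(100000)
    print("after torch.arange:", native_thread_count())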

    It's hard to say what exactly is happening to cause this specific deadlock, but it's likely that a mutex is being copied by the fork operation while in a locked state, and the copy never gets unlocked.
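
    To make that failure mode concrete, here is a minimal sketch of the same mechanism without pytorch: a plain threading.Lock is held by a background thread at the moment of the fork, so the child inherits the lock in its locked state but not the thread that owns it, and then blocks forever. (This is a simplified, assumed reproduction of the mechanism, not the exact deadlock inside torch.)

    import multiprocessing as mp
    import threading
    import time

    lock = threading.Lock()

    def hold_lock():
        # Background thread grabs the lock and never releases it.
        lock.acquire()
        time.sleep(60)

    def child():
        print("child: waiting for the lock...")
        lock.acquire()            # the forked copy of the lock is still locked,
        print("child: got it")    # and the thread that owns it doesn't exist here

    if __name__ == "__main__":
        threading.Thread(target=hold_lock, daemon=True).start()
        time.sleep(0.1)                                  # make sure the lock is held before forking
        p = mp.get_context("fork").Process(target=child)
        p.start()
        p.join()                                         # hangs, just like the original program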
    Regardless, the takeaway should be: use the spawn or forkserver start methods when you already have multiple threads running, and avoid fork unless you are willing to do a lot of tricky manual work to ensure all the forked threads play nicely (and you have a VERY good reason to do so). That is probably not feasible with torch, so just avoid fork.
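
    For completeness, here is a minimal sketch of your program using a per-context start method via mp.get_context, which avoids the fork without having to change the global start method:

    import multiprocessing as mp
    import torch

    def train():
        print("start of train")
        x = torch.arange(100000)
        print("end of train")

    if __name__ == "__main__":
        # A spawn (or forkserver) context starts the worker from a fresh
        # interpreter instead of forking the already-multithreaded parent.
        ctx = mp.get_context("spawn")    # "forkserver" also works on Linux
        p = ctx.Process(target=train)
        p.start()
        p.join()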
    By the way, pytorch has a page of best practices for multiprocessing, which you might find helpful.