pytorch multiprocessing fork

Child process hangs when performing inference with PyTorch model


I have a PyTorch model (class Net), together with its saved weights / state dict (net.pth), and I want to perform inference in a multiprocessing environment.

I noticed that I cannot simply create a model instance, load the weights, and then share the model with a child process (I had assumed this would work thanks to copy-on-write). Instead, the child hangs on y = model(x), and the whole program then hangs as well, because the parent blocks in os.waitpid.

The following is a minimal reproducible example:

import os

import torch

# Net is the model class (definition omitted)


def handler():
    with torch.no_grad():
        x = torch.rand(1, 3, 32, 32)
        y = model(x)

    return y


model = Net()
model.load_state_dict(torch.load("./net.pth"))

pid = os.fork()

if pid == 0:
    # this doesn't get printed as handler() hangs for the child process
    print('child:', handler())
else:
    # everything is fine here
    print('parent:', handler())
    os.waitpid(pid, 0)

If the model is created and loaded independently in the parent and in the child, i.e. with no sharing, everything works as expected. I have also tried calling share_memory_() on the model's tensors, but to no avail.
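While debugging a hang like this, it can help if the parent does not block indefinitely in os.waitpid. The following is only a sketch of one way to do that, using the standard os.WNOHANG flag to poll the child; the helper name wait_with_timeout and its parameters are my own and not part of PyTorch or of the example above.

```python
import os
import signal
import time


def wait_with_timeout(pid, timeout_s, poll_s=0.05):
    """Wait for child `pid`; kill it if it is still running after `timeout_s` seconds.

    Returns (exited_in_time, raw_status), where raw_status is the status
    value reported by os.waitpid.
    """
    deadline = time.monotonic() + timeout_s
    while time.monotonic() < deadline:
        # With WNOHANG, waitpid returns (0, 0) immediately while the child is still running
        done_pid, status = os.waitpid(pid, os.WNOHANG)
        if done_pid == pid:
            return True, status
        time.sleep(poll_s)

    # The child did not exit before the deadline: kill it, then reap it
    # so that no zombie process is left behind
    os.kill(pid, signal.SIGKILL)
    _, status = os.waitpid(pid, 0)
    return False, status
```

In the example above, replacing os.waitpid(pid, 0) with wait_with_timeout(pid, 30.0) would let the parent report that the child got stuck instead of hanging with it. It does not fix the hang itself.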

Am I doing something obviously wrong here?


Solution

  • It seems that sharing the state dict and calling load_state_dict in each process after the fork solves the problem:

    import os
    
    import torch
    
    LOADED = False
    
    def handler():
        global LOADED
        if not LOADED:
            # each process loads state independently
            model.load_state_dict(state)
            LOADED = True
    
        with torch.no_grad():
            x = torch.rand(1, 3, 32, 32)
            y = model(x)
    
        return y
    
    
    model = Net()
    
    # share the state rather than loading the state dict in parent
    # model.load_state_dict(torch.load("./net.pth"))
    state = torch.load("./net.pth")
    
    pid = os.fork()
    
    if pid == 0:
        print('child:', handler())
    else:
        print('parent:', handler())
        os.waitpid(pid, 0)
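    Note that the LOADED flag above gives "once per process" behaviour only because handler() is first called after the fork. If the parent called it before forking, the child would inherit LOADED = True and skip the load. A minimal sketch that makes the per-process intent explicit keys the cache on os.getpid(). The class name PerProcess and its interface are hypothetical, and this thread does not establish whether reloading in a child after the parent has already loaded would avoid the hang.

```python
import os


class PerProcess:
    """Lazily build a resource once per process, keyed on the current PID."""

    def __init__(self, factory):
        self._factory = factory   # zero-argument callable that builds the resource
        self._owner_pid = None    # PID of the process that built the cached value
        self._value = None

    def get(self):
        pid = os.getpid()
        if self._owner_pid != pid:
            # First call in this process, including a forked child that
            # inherited the parent's cached value: build it again here.
            self._value = self._factory()
            self._owner_pid = pid
        return self._value
```

    With the names from the code above, the factory would be a function that calls model.load_state_dict(state) and returns model, and handler() would call get() on the PerProcess instance instead of checking LOADED.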