I have a PyTorch model (class Net), together with its saved weights / state dict (net.pth), and I want to perform inference in a multiprocessing environment. I noticed that I cannot simply create a model instance, load the weights, and then share the model with a child process (though I'd have assumed this is possible due to copy-on-write). What happens is that the child hangs on y = model(x), and eventually the whole program hangs (due to the parent's waitpid).
The following is a minimal reproducible example:
import os
import torch

def handler():
    with torch.no_grad():
        x = torch.rand(1, 3, 32, 32)
        y = model(x)
    return y

# Net is the model class mentioned above
model = Net()
model.load_state_dict(torch.load("./net.pth"))

pid = os.fork()
if pid == 0:
    # this doesn't get printed as handler() hangs in the child process
    print('child:', handler())
else:
    # everything is fine here
    print('parent:', handler())
    os.waitpid(pid, 0)
If the model loading is done independently for parent & child, i.e. no sharing, then everything works as expected. I have also tried calling share_memory_() on the model's tensors, but to no avail.
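Roughly, that attempt looked like this (a sketch, not the exact code; Module.share_memory() just calls share_memory_() on every parameter and buffer in place):

# sketch of the share_memory_ attempt described above
model = Net()
model.load_state_dict(torch.load("./net.pth"))
# move all parameters and buffers into shared memory before forking
model.share_memory()  # same effect as calling .share_memory_() on each tensor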
Am I doing something obviously wrong here?
It seems that sharing the state dict and performing the loading operation in each process solves the problem:
import os
import torch

LOADED = False

def handler():
    global LOADED
    if not LOADED:
        # each process loads the state dict into its model independently
        model.load_state_dict(state)
        LOADED = True
    with torch.no_grad():
        x = torch.rand(1, 3, 32, 32)
        y = model(x)
    return y

model = Net()
# share the raw state dict rather than loading it into the model in the parent
# model.load_state_dict(torch.load("./net.pth"))
state = torch.load("./net.pth")

pid = os.fork()
if pid == 0:
    print('child:', handler())
else:
    print('parent:', handler())
    os.waitpid(pid, 0)
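The key difference is that load_state_dict is never called before the fork; each process loads the shared state dict into its own model after forking. If several workers are needed, the same pattern extends naturally (a sketch, assuming Net, handler(), model and state are defined as above; the worker count is made up):

# sketch: the same share-the-state-dict approach with several forked workers
NUM_WORKERS = 4  # hypothetical worker count

children = []
for _ in range(NUM_WORKERS):
    pid = os.fork()
    if pid == 0:
        # the child loads the shared state dict into its own copy of the model
        print('child %d:' % os.getpid(), handler())
        os._exit(0)  # make sure the child never falls through into parent code
    children.append(pid)

# the parent runs inference too, then waits for all children
print('parent:', handler())
for pid in children:
    os.waitpid(pid, 0)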