pytorch torchvision

Torchvision model cannot be loaded from storage when no GPU is available


I trained a torchvision Mask R-CNN model on a GPU and saved it to disk using torch.save(model, model_name). On another machine, without a GPU, I try to load it again using torch.load(model_name). The model cannot be deserialized because torch does not know about the device cuda:0.

How can I 'convert' such a model so it can be used in non-GPU environments? I assume it is best practice to move a model to the CPU before saving it?


Solution

  • torch.load() has a map_location argument that specifies the device onto which the stored tensors are remapped while loading. On a machine without a GPU you can use

    torch.load(..., map_location='cpu')
    

    or pass any other device to load the model directly onto it.
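As a minimal sketch of the whole round trip, the example below saves a model's state_dict and reloads it with map_location set to the CPU. Several things here are assumptions for illustration and not part of the original question: the small nn.Linear layer stands in for the Mask R-CNN model, the file name model_state.pt is hypothetical, and saving the state_dict (instead of the whole pickled model) is a swapped-in approach that tends to be more portable across machines.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Hypothetical stand-in for the trained model; any nn.Module behaves the same way.
model = nn.Linear(4, 2)

# Use the GPU if one is present (to mimic the training machine), else stay on the CPU.
train_device = "cuda:0" if torch.cuda.is_available() else "cpu"
model.to(train_device)

# Hypothetical file name inside a temporary directory.
path = os.path.join(tempfile.mkdtemp(), "model_state.pt")

# Save only the parameters (state_dict), not the pickled model object.
torch.save(model.state_dict(), path)

# On the CPU-only machine: remap every stored tensor to the CPU while loading.
state = torch.load(path, map_location=torch.device("cpu"))

# Rebuild the same architecture, then copy the loaded parameters into it.
restored = nn.Linear(4, 2)
restored.load_state_dict(state)
restored.eval()

# Collect the device types of the loaded tensors to confirm where they ended up.
loaded_devices = {t.device.type for t in state.values()}
print(loaded_devices)
```

With map_location in place, moving the model to the CPU before saving is optional: the remapping happens at load time regardless of the device the tensors were saved from.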