
How to load pretrained PyTorch weights for


I am following this tutorial, https://pytorch.org/tutorials/advanced/super_resolution_with_onnxruntime.html, because I want to run a PyTorch model in ONNX Runtime. In the example, the pretrained weights are downloaded from a URL. How can I load pretrained weights from the local disk instead?

import torch
import torch.utils.model_zoo as model_zoo

# torch_model is the SuperResolutionNet instance created earlier in the tutorial

# Load pretrained model weights
model_url = 'https://s3.amazonaws.com/pytorch/test_data/export/superres_epoch100-44c6958e.pth'
path = "/content/best.pt"
batch_size = 1    # just a random number

# Initialize model with the pretrained weights
# Keep tensors on the CPU unless a GPU is available
map_location = lambda storage, loc: storage
if torch.cuda.is_available():
    map_location = None
torch_model.load_state_dict(model_zoo.load_url(model_url, map_location=map_location))

# set the model to inference mode
torch_model.eval()
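The map_location lambda above returns each storage unchanged, which keeps every tensor on the CPU while the file is deserialized. A minimal sketch, assuming a recent PyTorch, showing that map_location="cpu" is an equivalent and more readable spelling (the temporary file and the "weight" tensor are illustrative only):

```python
import os
import tempfile

import torch

# Save a small dict of tensors to a temporary file (illustrative stand-in for a checkpoint).
tmp_dir = tempfile.mkdtemp()
ckpt_path = os.path.join(tmp_dir, "demo.pt")
torch.save({"weight": torch.arange(6.0).reshape(2, 3)}, ckpt_path)

# The tutorial's lambda: return the storage as-is, i.e. leave it on the CPU.
via_lambda = torch.load(ckpt_path, map_location=lambda storage, loc: storage)

# The string form has the same effect and is easier to read.
via_string = torch.load(ckpt_path, map_location="cpu")

print(torch.equal(via_lambda["weight"], via_string["weight"]))
print(via_string["weight"].device)
```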

I want to load the weights from the local file stored in the path variable instead.


Solution

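  • To load the state dict from a local path, use torch.load instead of model_zoo.load_url:

    torch_model.load_state_dict(torch.load(path, map_location=map_location))

    Passing the same map_location as in the question lets a file saved on a GPU load on a CPU-only machine. The rest of the tutorial (torch_model.eval() and the ONNX export) stays the same.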
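A self-contained sketch of the full round trip, assuming the file at path was saved with torch.save(model.state_dict(), path). TinyNet and load_local_weights are hypothetical names: TinyNet stands in for the tutorial's SuperResolutionNet, and a temporary file stands in for /content/best.pt. The helper also unwraps a checkpoint dictionary that stores the weights under a "state_dict" key; that key name is an assumption, so check what your own file actually contains.

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for the tutorial's SuperResolutionNet."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 4, kernel_size=3, padding=1)

    def forward(self, x):
        return self.conv(x)


def load_local_weights(model, path):
    """Hypothetical helper: load weights from a local file and switch to inference mode."""
    # Keep tensors on the CPU when no GPU is present (same effect as the tutorial's lambda).
    map_location = None if torch.cuda.is_available() else "cpu"
    checkpoint = torch.load(path, map_location=map_location)
    # Assumption: some training scripts wrap the weights in a dict under "state_dict".
    if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
        checkpoint = checkpoint["state_dict"]
    model.load_state_dict(checkpoint)
    model.eval()
    return model


# Simulate a locally saved state dict (stand-in for "/content/best.pt").
trained = TinyNet()
path = os.path.join(tempfile.mkdtemp(), "best.pt")
torch.save(trained.state_dict(), path)

# Load it into a freshly initialized model and compare every tensor.
torch_model = load_local_weights(TinyNet(), path)
same = all(
    torch.equal(a, b)
    for a, b in zip(trained.state_dict().values(), torch_model.state_dict().values())
)
print(same)
```

If load_state_dict raises an error about missing or unexpected keys, the file was most likely saved from a different architecture, or it holds a whole pickled model rather than a state dict.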