Tags: python, pytorch, classification, resnet

What is the difference between saving a model and saving the weights in PyTorch?


I have the following code:

import cv2
import torch
import torchvision
import torchvision.transforms as T

x = cv2.imread('anormal_9979_AVsp2000_ciclo6.png')

transform = T.Compose([
    T.ToTensor(),
    T.Resize((64, 64)),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
x_t = transform(x)
batch_t = torch.unsqueeze(x_t, 0)  # add a batch dimension

resnet = torchvision.models.wide_resnet101_2(pretrained=False)
resnet.fc = torch.nn.Linear(resnet.fc.in_features, 2)
resnet.load_state_dict(torch.load('wide_resnet101_2.pt', map_location=torch.device('cpu')))
resnet.eval()
net = resnet(batch_t.float())

predictions = torch.argmax(net, dim=1)
print(predictions)

When it prints the prediction, the prediction is always wrong. However, when I trained the model and saved it, the test accuracy (test_acc) was approximately 95%.

In addition, this ResNet was trained for 2 classes ([0] for Normal and [1] for Anormal).

I tried to modify the grad_fn, but apparently I can't. I have also seen that you can save just the weights of the pretrained model, but doesn't model.pt already contain the final weights?


Solution

  • The first parameter of nn.Module.load_state_dict is a state_dict: a dictionary that maps each parameter (and buffer) name to its tensor.

    The call below does not load the .pt file and copy it into the model; it just passes the file name string as the state_dict, which fails because a string is not a dictionary.

    resnet.load_state_dict('wide_resnet101_2.pt')
    

    The correct way is to first read the weights from the file with torch.load, and then copy them into the model with load_state_dict.

    state_dict = torch.load('wide_resnet101_2.pt')
    resnet.load_state_dict(state_dict)
    

    The PyTorch documentation provides a neat, beginner-friendly overview of saving and loading models.

    However, the right way to load also depends on how you saved the model.

    # Option 1: save the entire model object (pickles the whole module)
    torch.save(model, PATH)
    model = torch.load(PATH)  # on PyTorch 2.6+, pass weights_only=False here

    # Option 2: save only the state_dict (the weights)
    torch.save(model.state_dict(), PATH)
    # The state_dict must be loaded from the file first, then copied into
    # an already-constructed model of the same architecture
    model.load_state_dict(torch.load(PATH))
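
To make the state_dict route concrete, here is a minimal sketch of a save/load round trip. It is only an illustration under a few assumptions: the tiny nn.Sequential model and the make_model helper are hypothetical stand-ins for wide_resnet101_2, and an in-memory buffer replaces the .pt file so the snippet runs on its own.

```python
import io

import torch
from torch import nn


def make_model():
    # Hypothetical stand-in for the real network (the question uses wide_resnet101_2).
    return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))


torch.manual_seed(0)
trained = make_model()

# A state_dict maps parameter (and buffer) names to tensors.
state_dict_keys = list(trained.state_dict().keys())

# Save only the weights; the in-memory buffer stands in for a .pt file on disk.
buffer = io.BytesIO()
torch.save(trained.state_dict(), buffer)
buffer.seek(0)

# Loading requires a freshly built model with the same architecture.
restored = make_model()
result = restored.load_state_dict(torch.load(buffer, map_location="cpu"))

# Passing the file name itself is rejected, because a str is not a state_dict.
try:
    restored.load_state_dict("wide_resnet101_2.pt")
    string_rejected = False
except (TypeError, AttributeError):
    string_rejected = True

# After loading, both models should produce identical outputs for the same input.
trained.eval()
restored.eval()
x = torch.randn(1, 4)
with torch.no_grad():
    same_output = torch.equal(trained(x), restored(x))

print(state_dict_keys)
print(same_output, string_rejected)
```

The value returned by load_state_dict lists any missing or unexpected keys, so it is a quick way to check that the file really matches the architecture you rebuilt (for example, after replacing resnet.fc with a 2-class layer).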