
load_state_dict() missing 1 required positional argument: 'state_dict'


I am trying to load my trained CNN model. This is the model architecture:

import torch
import torch.nn as nn
import torch.nn.functional as F


class ConvNet(nn.Module):
    def __init__(self):
        super(ConvNet, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=2, out_channels=4, kernel_size=4)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(4, 8, 4)
        self.fc1 = nn.Linear(8 * 6 * 6, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 1)

    def forward(self, x):
        # x: (n, 2, H, W)
        x = self.pool(F.relu(self.conv1(x)))  # -> 4 channels
        x = self.pool(F.relu(self.conv2(x)))  # -> must be (n, 8, 6, 6) to match fc1
        x = torch.flatten(x, 1)               # -> (n, 288), batch dimension kept
        x = F.relu(self.fc1(x))               # -> (n, 120)
        x = F.relu(self.fc2(x))               # -> (n, 84)
        # x = nn.LeakyReLU(0.1)(self.fc3(x))
        x = self.fc3(x)                       # -> (n, 1)
        return x

I trained this model on one dataset and want to test it on another, so I saved its state dict as follows:

filepath = r'C:/Users/Q559366/Desktop/code check/CNN_Model/cnnMLmodel300Epoch.pth'
torch.save(model.state_dict(), filepath) 

However, while loading the model, I got an error.

import torch
from src.data import CarBonnetSource
from src.model import ConvNet

model = ConvNet
model.load_state_dict(torch.load(filepath))
model.eval()

The error message is: load_state_dict() missing 1 required positional argument: 'state_dict'

What can I do to load my model successfully?


Solution

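  • ConvNet has not been instantiated: model = ConvNet binds the class itself, not an instance. Calling load_state_dict on the class passes the loaded dictionary in place of self, which leaves state_dict unfilled and raises the error. Instantiate the model first: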

    model = ConvNet()
    model.load_state_dict(torch.load(filepath))
    model.eval()
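
The mechanism behind the message can be reproduced without PyTorch. The sketch below is only an illustration: FakeModule and the checkpoint dictionary are hypothetical stand-ins (not part of the original code or of PyTorch) for nn.Module and the result of torch.load(filepath). It shows that calling the method through the class consumes the dictionary as self, while calling it through an instance works.

```python
class FakeModule:
    """Hypothetical stand-in for nn.Module, for illustration only."""

    def __init__(self):
        self.weights = {}

    def load_state_dict(self, state_dict):
        # Same parameter names as the method named in the error message.
        self.weights = dict(state_dict)
        return self


# Made-up stand-in for the dictionary returned by torch.load(filepath).
checkpoint = {"fc.weight": [1.0, 2.0]}

# Called through the class: `checkpoint` is bound to `self`,
# so nothing is left to fill `state_dict`.
try:
    FakeModule.load_state_dict(checkpoint)
    error_message = None
except TypeError as exc:
    error_message = str(exc)

# Called through an instance: Python supplies `self` automatically.
model = FakeModule()
model.load_state_dict(checkpoint)

print(error_message)
print(model.weights)
```

The same reasoning applies to model.eval(): it is also an instance method, so it needs the parentheses in ConvNet() as well.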