I am trying to load my saved CNN model. Below is the code I have written, starting with my CNN model architecture:
class ConvNet(nn.Module):
    """Small two-stage CNN that maps each 2-channel input sample to one output value.

    ``fc1`` is sized for an 8 x 6 x 6 feature map. With this conv/pool stack,
    that means the input height and width must each be in the range 33..36
    (e.g. 36 x 36).

    The attribute names and layer shapes below define the ``state_dict`` keys;
    they must match any checkpoint that is loaded into this model.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=2, out_channels=4, kernel_size=4)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(4, 8, 4)
        self.fc1 = nn.Linear(8 * 6 * 6, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 1)

    def forward(self, x):
        """Run the network on a batch (or a single sample).

        Args:
            x: tensor of shape ``(n, 2, H, W)`` (batched) or ``(2, H, W)``
               (one unbatched sample), with H and W in 33..36.

        Returns:
            Tensor of shape ``(n, 1)`` for batched input, ``(1,)`` for
            unbatched input.
        """
        x = self.pool(F.relu(self.conv1(x)))  # 36x36 input -> n, 4, 16, 16
        x = self.pool(F.relu(self.conv2(x)))  # -> n, 8, 6, 6
        if x.dim() == 4:
            # Keep the batch dimension. Flattening every dim would merge all
            # samples into one vector, and fc1 would fail for batch size > 1.
            x = x.flatten(1)  # -> n, 288
        else:
            x = x.flatten()  # unbatched (C, H, W) input -> 288
        x = F.relu(self.fc1(x))  # -> n, 120
        x = F.relu(self.fc2(x))  # -> n, 84
        return self.fc3(x)  # -> n, 1
I trained this model on one dataset and want to test it on another dataset, so I saved the model with the following code:
# Save only the trained parameters (the state dict), not the whole pickled
# module object; the path is split across two raw literals for readability.
filepath = (
    r'C:/Users/Q559366/Desktop/code check/CNN_Model/'
    r'cnnMLmodel300Epoch.pth'
)
torch.save(model.state_dict(), filepath)
However, when I load the model with the code below, I get an error.
import torch
from src.data import CarBonnetSource
from src.model import ConvNet
# NOTE: this reproduces the error described below. `ConvNet` here is the
# class itself, not an instance, so `model.load_state_dict(...)` calls the
# method unbound: the loaded dict is taken as `self` and the `state_dict`
# argument is missing. The fix is `model = ConvNet()`.
model = ConvNet
model.load_state_dict(torch.load(filepath))
model.eval()
But I got this error: `load_state_dict() missing 1 required positional argument: 'state_dict'`
What can I do to load my model successfully?
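Your `ConvNet` has not been instantiated yet: `model = ConvNet` binds the class itself, not an object. Calling `model.load_state_dict(state)` therefore calls the method unbound. The loaded dictionary is taken as `self`, and the required `state_dict` argument is missing — hence the error. To fix it, create an instance first: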
# Instantiate the model first: load_state_dict is an instance method and
# needs a ConvNet object whose parameters it can overwrite.
model = ConvNet()
# map_location="cpu" lets a checkpoint saved from a GPU run load on a
# CPU-only machine; the freshly built model lives on the CPU anyway.
# NOTE: torch.load unpickles the file -- only load checkpoints you trust
# (on PyTorch >= 1.13 you can also pass weights_only=True).
model.load_state_dict(torch.load(filepath, map_location="cpu"))
# Put the model in evaluation mode before testing on the new dataset
# (matters for layers such as dropout or batch norm).
model.eval()