python, pytorch, conv-neural-network, torch

PyTorch CNN Different Input Size


Hello guys, I have a question about different input sizes.

My training and validation datasets have an input size of 256, while my prediction data (an unseen test dataset) has an input size of 496.

import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self, shape):
        super(Net, self).__init__()
        self.conv1 = nn.Conv1d(shape, 1, 1)
        self.batch1 = nn.BatchNorm1d(1)
        self.avgpl1 = nn.AvgPool1d(1, stride=1)
        self.fc1 = nn.Linear(1, 3)

    # forward method
    def forward(self, x):
        x = self.conv1(x)
        x = self.batch1(x)
        x = F.relu(x)
        x = self.avgpl1(x)
        x = torch.flatten(x, 1)
        x = F.log_softmax(self.fc1(x), dim=1)
        return x

I saved the model and want to use it for my prediction as well.

The error message is:

Input In [244], in predict_data(prediction_data, model_path, data_config, context)
     25 new_model = Net(shape_preprocessed_data)
     26 # load the previously saved state_dict
---> 27 new_model.load_state_dict(torch.load("NetModel.pth"))
     29 # check if predictions of models are equal
     30 
     31 # generate random input of size (N,C,H,W)
     32 
     33 # switch to eval mode for both models
     34 model = model.eval()

    RuntimeError: Error(s) in loading state_dict for Net:
    size mismatch for conv1.weight: copying a param with shape torch.Size([1, 256, 1]) from checkpoint, the shape in current model is torch.Size([1, 494, 1]).
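
For reference, the mismatch is baked into the checkpoint: conv1.weight was saved with the in_channels value (shape) used at training time. One way to check this is to inspect the saved state_dict directly (a minimal sketch, assuming the NetModel.pth file from the traceback is available):

import torch

# Load only the saved parameters, without constructing a model.
state_dict = torch.load("NetModel.pth", map_location="cpu")

# conv1 was created as nn.Conv1d(shape, 1, 1), so conv1.weight has shape
# (out_channels, in_channels, kernel_size) = (1, shape, 1).
print(state_dict["conv1.weight"].shape)  # torch.Size([1, 256, 1]) for the model trained with shape=256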

How can I solve this?

Solution

  • You could reshape/downsample the input as the first step of the forward pass in your model. This can be done using the torch.nn.functional.interpolate function.

    For example:

    class Net(nn.Module):
        def __init__(self, shape):
            super(Net, self).__init__()
            self.input_shape = shape
            self.conv1 = nn.Conv1d(shape, 1, 1)
            self.batch1 = nn.BatchNorm1d(1)
            self.avgpl1 = nn.AvgPool1d(1, stride=1)
            self.fc1 = nn.Linear(1, 3)

        # forward method
        def forward(self, x):
            # Resize the incoming tensor to the size used during training
            # before it reaches the rest of the network.
            x = torch.nn.functional.interpolate(x, size=self.input_shape)
            x = self.conv1(x)
            x = self.batch1(x)
            x = F.relu(x)
            x = self.avgpl1(x)
            x = torch.flatten(x, 1)
            x = F.log_softmax(self.fc1(x), dim=1)
            return x

    Your test inputs would then be downsampled to size 256 so that they are compatible with the model.
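
    As a quick sanity check of how the resizing behaves (a minimal standalone sketch, assuming a (batch, channels, length) tensor layout; the sizes are made up):

    import torch
    import torch.nn.functional as F

    # F.interpolate resizes the trailing (length) dimension of a 3D tensor.
    x = torch.randn(8, 1, 496)                # hypothetical test batch, length 496
    x_small = F.interpolate(x, size=256)
    print(x_small.shape)                      # torch.Size([8, 1, 256])

    # If the axis that differs between training and test is the channel
    # dimension (as the conv1.weight error in the question suggests), permute
    # it into the length position, resize, and permute back:
    y = torch.randn(8, 496, 1)                # hypothetical batch with 496 "channels"
    y_small = F.interpolate(y.permute(0, 2, 1), size=256).permute(0, 2, 1)
    print(y_small.shape)                      # torch.Size([8, 256, 1])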