Tags: python, pytorch, shapes, torch, conv-neural-network

Custom CNN gives wrong output shape


I need some help. I am trying to build a custom CNN that accepts single-channel images and performs binary classification. This is the model:

import numpy as np
import torch
import torch.nn as nn

class custom_small_CNN(nn.Module):

    def __init__(self, input_channels=1, output_features=1):
        super(custom_small_CNN, self).__init__()

        self.input_channels = input_channels
        self.output_features = output_features

        self.conv1 = nn.Conv2d(self.input_channels, 8, kernel_size=(7, 7), stride=(2, 2), padding=(6, 6), dilation=(2, 2))
        self.conv2 = nn.Conv2d(8, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), dilation=(1, 1))
        self.pool = nn.MaxPool2d(kernel_size=(2, 2))
        self.fc1 = nn.Linear(in_features=1024, out_features=self.output_features, bias=True)
        self.dropout = nn.Dropout(p=0.5)
        self.softmax = nn.Softmax(dim=1)
        self.net_name = 'Custom_Small_CNN'

        self.net = nn.Sequential(self.conv1, self.pool, self.conv2, self.pool, self.fc1)

    def forward(self, x):
        x = self.conv1(x)
        x = self.pool(x)
        #x = self.dropout(x)
        x = self.conv2(x)
        x = self.pool(x)
        x = x.view(-1, 1024)
        x = self.dropout(x)
        x = self.fc1(x)
        if self.output_features != 1:
            x = self.softmax(x)
        return x

However, when I pass an example batch of 4 images (all zeros) through the model like this:

x = torch.from_numpy(np.zeros((4, 1, 256, 256))).float()
net = custom_small_CNN(output_features=2, input_channels=1).float()
output = net(x)

the output has shape torch.Size([16, 2]) instead of torch.Size([4, 2]), which is what I want and what, e.g., a ResNet returns. What am I missing? Thanks!
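For reference, the 16 rows can be reproduced by hand. The torch.nn.Conv2d and torch.nn.MaxPool2d documentation give the output size along each spatial dimension as H_out = floor((H_in + 2*padding - dilation*(kernel_size - 1) - 1) / stride) + 1, where MaxPool2d's stride defaults to its kernel_size. The plain-Python sketch below applies that formula layer by layer for a 256x256 input; conv2d_out is a hypothetical helper written for this illustration, not a PyTorch function.

```python
def conv2d_out(size, kernel, stride=1, padding=0, dilation=1):
    """Hypothetical helper: output length along one spatial dimension,
    using the formula from the torch.nn.Conv2d / MaxPool2d docs."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


h = 256
h = conv2d_out(h, kernel=7, stride=2, padding=6, dilation=2)  # conv1 -> 128
h = conv2d_out(h, kernel=2, stride=2)                         # pool  -> 64
h = conv2d_out(h, kernel=3, stride=2, padding=1)              # conv2 -> 32
h = conv2d_out(h, kernel=2, stride=2)                         # pool  -> 16

channels = 16                     # out_channels of conv2
per_image = channels * h * h      # values per image just before the reshape
batch = 4
rows = batch * per_image // 1024  # number of rows x.view(-1, 1024) produces
print(h, per_image, rows)         # 16 4096 16
```

So each image carries 4096 values, four times the 1024 that fc1 was sized for, and the reshape turns a batch of 4 into 16 rows.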


Solution

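  • PyTorch's Conv2d and MaxPool2d layers use the (batch_size, channels, height, width) layout. For a 256x256 input, conv1 and the first pooling layer reduce the feature maps to 64x64, and conv2 and the second pooling layer reduce them to 16x16, so just before the reshape x has shape (batch_size, 16, 16, 16), i.e. 16 * 16 * 16 = 4096 values per image. x.view(-1, 1024) ignores the batch dimension: it simply cuts the 4 * 4096 values into rows of 1024, which gives 16 rows (4 per image). fc1 then maps each of those rows to 2 outputs, hence torch.Size([16, 2]).

    Instead of reshaping like that, keep the batch dimension and either flatten or average each image's feature maps. Flattening is the most common choice here, but then the number of input features of fc1 has to match the flattened size.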

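    So, replace the following line

    x = x.view(-1, 1024)

    with

    x = torch.flatten(x, start_dim=1)  # same as nn.Flatten()(x): keeps dim 0 as the batch dimension

    and change fc1 so that it expects 4096 input features:

    self.fc1 = nn.Linear(in_features=16 * 16 * 16, out_features=self.output_features, bias=True)

    With both changes the output has the expected shape torch.Size([4, 2]).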
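Putting both changes together, a corrected version of the model might look like the sketch below. It assumes 256x256 inputs (the 4096 in fc1 depends on that) and drops the unused self.net and net_name attributes for brevity; the layers are otherwise the ones from the question.

```python
import torch
import torch.nn as nn


class custom_small_CNN(nn.Module):
    def __init__(self, input_channels=1, output_features=1):
        super().__init__()
        self.input_channels = input_channels
        self.output_features = output_features

        self.conv1 = nn.Conv2d(input_channels, 8, kernel_size=7, stride=2, padding=6, dilation=2)
        self.conv2 = nn.Conv2d(8, 16, kernel_size=3, stride=2, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2)
        # 16 channels * 16 * 16 spatial positions, valid only for 256x256 inputs
        self.fc1 = nn.Linear(in_features=16 * 16 * 16, out_features=output_features)
        self.dropout = nn.Dropout(p=0.5)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        x = self.pool(self.conv1(x))       # (N, 8, 64, 64)
        x = self.pool(self.conv2(x))       # (N, 16, 16, 16)
        x = torch.flatten(x, start_dim=1)  # (N, 4096): one row per image
        x = self.dropout(x)
        x = self.fc1(x)                    # (N, output_features)
        if self.output_features != 1:
            x = self.softmax(x)
        return x


x = torch.zeros(4, 1, 256, 256)
net = custom_small_CNN(input_channels=1, output_features=2)
output = net(x)
print(output.shape)  # torch.Size([4, 2])
```

If the model should accept other input sizes, one option is the "average" route: apply nn.AdaptiveAvgPool2d(1) before flattening, which reduces each of the 16 feature maps to a single value regardless of spatial size, so fc1 would then take in_features=16.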