Search code examples
pytorchunet-neural-network

Unet pytorch dimension mismatch


I got the following U-net architecture causing problems:

class UNet(nn.Module): 
    def __init__(self, in_channels, out_channels):
        super(UNet, self).__init__()
        self.encoder1 = self.double_conv(in_channels, 64)
        self.encoder2 = self.down(64, 128)
        self.encoder3 = self.down(128, 256)
        self.encoder4 = self.down(256, 512)
        self.bottleneck = self.double_conv(512, 1024)
        self.decoder4 = self.up(1024, 512)
        self.decoder3 = self.up(512, 256)
        self.decoder2 = self.up(256, 128)
        self.decoder1 = self.up(128, 64)
        self.final_conv = nn.Conv2d(64, out_channels, kernel_size=1) # SAME convolution/padding

    def double_conv(self, in_channels, out_channels): # Convo Block
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )

    def down(self, in_channels, out_channels):
        return nn.Sequential(
            nn.MaxPool2d(kernel_size=2, stride=2),
            self.double_conv(in_channels, out_channels),
        )

    def up(self, in_channels, out_channels):
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2),
            self.double_conv(in_channels, out_channels),
        )

    def forward(self, x):
        # Encoder
        enc1 = self.encoder1(x)  # Output: [1, 64, 256, 256]
        print("enc1.shape",enc1.shape)
        enc2 = self.encoder2(enc1)  # Output: [1, 128, 128, 128]
        print("enc2.shape",enc2.shape)
        enc3 = self.encoder3(enc2)  # Output: [1, 256, 64, 64]
        print("enc3.shape",enc3.shape)
        enc4 = self.encoder4(enc3)  # Output: [1, 512, 32, 32]
        print("enc4.shape",enc4.shape)
        bottleneck_output = self.bottleneck(enc4)  # Output: [1, 1024, 32, 32]
        print("bottleneck_output",bottleneck_output.shape)
        
        # Decoder
        dec4 = self.decoder4(bottleneck_output)#bottleneck_output)  # Output: [1, 512, 64, 64]
        print(dec4.shape)
        dec4 = torch.cat((dec4, enc4), dim=1)  # skip connect, Concatenate: [1, 1024, 64, 64]
        dec4 = self.double_conv(1024, 512)(dec4)  # Corrected input channels to 1024

        dec3 = self.decoder3(dec4)  # Output: [1, 256, 128, 128]
        dec3 = torch.cat((dec3, enc3), dim=1)  # Concatenate: [1, 512, 128, 128]
        dec3 = self.double_conv(512, 256)(dec3)  # Corrected input channels to 512

        dec2 = self.decoder2(dec3)  # Output: [1, 128, 256, 256]
        dec2 = torch.cat((dec2, enc2), dim=1)  # Concatenate: [1, 256, 256, 256]
        dec2 = self.double_conv(256, 128)(dec2)  # Corrected input channels to 256

        dec1 = self.decoder1(dec2)  # Output: [1, 64, 512, 512]
        dec1 = torch.cat((dec1, enc1), dim=1)  # Concatenate: [1, 128, 512, 512]
        dec1 = self.double_conv(128, 64)(dec1)  # Corrected input channels to 128

        return self.final_conv(dec1)  # Output: [1, 1, 512, 512]```

When executing in a main method via

unet = UNet(in_channels=1, out_channels=1)
sample_input = torch.randn(1, 1, 256, 256)
output = unet(sample_input)

I get:

enc1.shape torch.Size([1, 64, 256, 256])
enc2.shape torch.Size([1, 128, 128, 128])
enc3.shape torch.Size([1, 256, 64, 64])
enc4.shape torch.Size([1, 512, 32, 32])
bottleneck_output torch.Size([1, 1024, 32, 32])

and the following error:

---> 55 dec4 = self.decoder4(bottleneck_output)

RuntimeError: Given groups=1, weight of size [512, 1024, 3, 3], expected input[1, 512, 64, 64] to have 1024 channels, but got 512 channels instead

So the problem apparently is the bottleneck_output shape which does have 1024 channels, but the decoder4 does not seem to recognise it or sth. like that.

I tried matching the dimensions and other things like an align function but nothing worked so far. Also printing out the output shapes didn't really help. Thanks for any hints.


Solution

  • Your problem is with the definition of up method:

    def up(self, in_channels, out_channels):
            return nn.Sequential(
                nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2),
                self.double_conv(in_channels, out_channels),
            )
    

    ConvTranspose2d outputs a tensor with out_channels channels but double_conv expects an input tensor of in_channels channels.

    You should probably use something like:

    def up(self, in_channels, out_channels):
            return nn.Sequential(
                nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2),
                self.double_conv(out_channels, out_channels), # NOTE CHANGE HERE
            )