Tags: python, deep-learning, pytorch, autoencoder

My convolutional autoencoder should return the same shape as the input, but it does not


I am using a convolutional autoencoder to reconstruct a spectrogram of shape (1, 100, 592), but it returns an output of shape (1, 90, 586).

My Encoder is:

import torch
import torch.nn as nn
import torch.nn.functional as F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # used by .to(device) below

class Encoder(nn.Module):
    def __init__(self, latent_dim):
        super(Encoder, self).__init__()

        # define convolutional layers for encoding
        #! in_channels is 1 because the spectrogram has only 1 channel
        
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=512, kernel_size=(5, 5), stride=(2, 2))
        self.batch1 = nn.BatchNorm2d(512)
        self.conv2 = nn.Conv2d(in_channels=512, out_channels=256, kernel_size=(3, 3), stride=(2, 2))
        self.batch2 = nn.BatchNorm2d(256)
        self.conv3 = nn.Conv2d(in_channels=256, out_channels=128, kernel_size=(3, 3), stride=(2, 2))
        self.batch3 = nn.BatchNorm2d(128)
        self.conv4 = nn.Conv2d(in_channels=128, out_channels=64, kernel_size=(2, 2), stride=(2, 2))
        self.batch4 = nn.BatchNorm2d(64)
        self.conv5 = nn.Conv2d(in_channels=64, out_channels=32, kernel_size=(1, 1), stride=(1, 1))

        # define latent space
        # out_size_len is the flattened length of the conv output, computed
        # beforehand via get_output_size (32 * 5 * 36 = 5760 for this input)
        self.z = nn.Linear(out_size_len, latent_dim)

    def forward(self, x):
        x = x.to(device)  # move input to GPU
        x = F.relu(self.batch1(self.conv1(x)))  # apply convolutional layer, batch norm, and ReLU activation
        x = F.relu(self.batch2(self.conv2(x)))
        x = F.relu(self.batch3(self.conv3(x)))
        x = F.relu(self.batch4(self.conv4(x)))
        x = F.relu(self.conv5(x))
        x = torch.flatten(x, start_dim=1)  # flatten output for linear layers
        
        z = F.relu(self.z(x))
        
        return z

    #! helper used to obtain out_size / out_size_len for the layers above
    def get_output_size(self, input_shape):
        x = torch.zeros(input_shape)
        x = x.to(device)
        x = self.conv1(x)
        x = self.batch1(x)
        x = self.conv2(x)
        x = self.batch2(x)
        x = self.conv3(x)
        x = self.batch3(x)
        x = self.conv4(x)
        x = self.batch4(x)
        x = self.conv5(x)
        output_size = x.size()[1:]
        return output_size
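
For reference, each Conv2d layer shrinks a spatial dimension as out = floor((in + 2*padding - kernel) / stride) + 1. A minimal sketch (the conv2d_out helper is my own, not part of the model) that traces the encoder's shapes for the (100, 592) input:

def conv2d_out(size, kernel, stride, padding=0):
    # Conv2d output size: floor((size + 2*padding - kernel) / stride) + 1
    return (size + 2 * padding - kernel) // stride + 1

h, w = 100, 592
for kernel, stride in [(5, 2), (3, 2), (3, 2), (2, 2), (1, 1)]:  # conv1..conv5
    h, w = conv2d_out(h, kernel, stride), conv2d_out(w, kernel, stride)
    print(h, w)  # 48 294 -> 23 146 -> 11 72 -> 5 36 -> 5 36
# so out_size = (32, 5, 36) and out_size_len = 32 * 5 * 36 = 5760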

The Decoder is:

class Decoder(nn.Module):

    def __init__(self, latent_dims):
        super().__init__()

        # linear layers to map the latent code back to a flat feature vector
        self.decoder_lin = nn.Sequential(
            nn.Linear(latent_dims, 128),
            nn.ReLU(True),
            #! out_size_len is the flattened size obtained from get_output_size
            nn.Linear(128, out_size_len),
            nn.ReLU(True)
        )

        # unflatten the (N, features) tensor back to (N, C, H, W); out_size comes from get_output_size
        self.unflatten = nn.Unflatten(dim=1, unflattened_size=(int(out_size[0]), int(out_size[1]), int(out_size[2])))

        # transposed convolutional layers to gradually increase spatial resolution
        self.dec1 = nn.ConvTranspose2d(in_channels=32, out_channels=64, kernel_size=(1, 1), stride=(1, 1))
        self.batch1 = nn.BatchNorm2d(64)
        self.dec2 = nn.ConvTranspose2d(in_channels=64, out_channels=128, kernel_size=(2, 2), stride=(2, 2))
        self.batch2 = nn.BatchNorm2d(128)
        self.dec3 = nn.ConvTranspose2d(in_channels=128, out_channels=256, kernel_size=(3, 3), stride=(2, 2))
        self.batch3 = nn.BatchNorm2d(256)
        self.dec4 = nn.ConvTranspose2d(in_channels=256, out_channels=512, kernel_size=(3, 3), stride=(2, 2), output_padding=1)
        self.batch4 = nn.BatchNorm2d(512)
        self.dec5 = nn.ConvTranspose2d(in_channels=512, out_channels=1, kernel_size=(3, 3), stride=(2, 2), output_padding=1)

    def forward(self, x):
        # map the latent code to a flat feature vector
        x = self.decoder_lin(x)
        # unflatten (N, features) back to (N, C, H, W)
        x = self.unflatten(x)
        # transposed convolutional layers to gradually increase spatial resolution
        x = F.relu(self.batch1(self.dec1(x)))
        x = F.relu(self.batch2(self.dec2(x)))
        x = F.relu(self.batch3(self.dec3(x)))
        x = F.relu(self.batch4(self.dec4(x)))
        x = self.dec5(x)
        # apply sigmoid activation function to ensure output is between 0 and 1
        x = torch.sigmoid(x)
        return x
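
For ConvTranspose2d the relation is inverted: out = (in - 1) * stride - 2*padding + kernel + output_padding. Tracing the decoder's hyperparameters with that formula (again a sketch; convT_out is my own helper, and the (5, 36) starting size comes from the encoder above) shows how the shapes drift to (90, 586):

def convT_out(size, kernel, stride, padding=0, output_padding=0):
    # ConvTranspose2d output size: (size - 1)*stride - 2*padding + kernel + output_padding
    return (size - 1) * stride - 2 * padding + kernel + output_padding

h, w = 5, 36
for kernel, stride, out_pad in [(1, 1, 0), (2, 2, 0), (3, 2, 0), (3, 2, 1), (3, 2, 1)]:  # dec1..dec5
    h = convT_out(h, kernel, stride, output_padding=out_pad)
    w = convT_out(w, kernel, stride, output_padding=out_pad)
    print(h, w)  # 5 36 -> 10 72 -> 21 145 -> 44 292 -> 90 586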

The Encoder-Decoder:

class Enc_Dec(nn.Module):
    def __init__(self, latent_dims):
        super(Enc_Dec, self).__init__()
        self.encoder = Encoder(latent_dims)
        self.decoder = Decoder(latent_dims)

    def forward(self, x):
        x = x.to(device) 
        z = self.encoder(x)
        return self.decoder(z)
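
A quick forward pass on random data reproduces the mismatch (a sketch; it assumes out_size_len = 5760 and out_size = (32, 5, 36) have been computed as above):

model = Enc_Dec(latent_dims=20).to(device)
x = torch.rand(1, 1, 100, 592)  # one single-channel (100, 592) spectrogram
print(model(x).shape)           # torch.Size([1, 1, 90, 586]) instead of ([1, 1, 100, 592])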

So, to investigate, I used torchsummary, with the following output:

from torchsummary import summary
summary(Enc_Dec(latent_dims=20), X[0].shape)  # X[0].shape is (1, 100, 592)
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 512, 48, 294]          13,312
       BatchNorm2d-2         [-1, 512, 48, 294]           1,024
            Conv2d-3         [-1, 256, 23, 146]       1,179,904
       BatchNorm2d-4         [-1, 256, 23, 146]             512
            Conv2d-5          [-1, 128, 11, 72]         295,040
       BatchNorm2d-6          [-1, 128, 11, 72]             256
            Conv2d-7            [-1, 64, 5, 36]          32,832
       BatchNorm2d-8            [-1, 64, 5, 36]             128
            Conv2d-9            [-1, 32, 5, 36]           2,080
           Linear-10                   [-1, 20]         115,220
          Encoder-11                   [-1, 20]               0
           Linear-12                  [-1, 128]           2,688
             ReLU-13                  [-1, 128]               0
           Linear-14                 [-1, 5760]         743,040
             ReLU-15                 [-1, 5760]               0
        Unflatten-16            [-1, 32, 5, 36]               0
  ConvTranspose2d-17            [-1, 64, 5, 36]           2,112
      BatchNorm2d-18            [-1, 64, 5, 36]             128
  ConvTranspose2d-19          [-1, 128, 10, 72]          32,896
      BatchNorm2d-20          [-1, 128, 10, 72]             256
  ConvTranspose2d-21         [-1, 256, 21, 145]         295,168
      BatchNorm2d-22         [-1, 256, 21, 145]             512
  ConvTranspose2d-23         [-1, 512, 44, 292]       1,180,160
      BatchNorm2d-24         [-1, 512, 44, 292]           1,024
  ConvTranspose2d-25           [-1, 1, 90, 586]           4,609
          Decoder-26           [-1, 1, 90, 586]               0
================================================================
Total params: 3,902,901
Trainable params: 3,902,901
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.23
Forward/backward pass size (MB): 239.92
Params size (MB): 14.89
Estimated Total Size (MB): 255.04
----------------------------------------------------------------

I see the difference starts at the second ConvTranspose2d in the Decoder. So I tried switching some values of stride and padding, but I could not reach the desired output shape.

How can I solve this?


Solution

  • Basically, it's just a matter of adjusting the paddings of your decoder. Here is a solution:

    class Decoder(nn.Module):
    
        def __init__(self, latent_dims):
            super().__init__()
    
            # linear layers to map the latent code back to a flat feature vector
            self.decoder_lin = nn.Sequential(
                nn.Linear(latent_dims, 128),
                nn.ReLU(True),
                #! 5760 = 32 * 5 * 36, the flattened size obtained from get_output_size
                nn.Linear(128, 5760),
                nn.ReLU(True)
            )
    
            # unflatten the (N, features) tensor back to (N, C, H, W)
            self.unflatten = nn.Unflatten(dim=1, unflattened_size=(32, 5, 36))
    
            # transposed convolutional layers to gradually increase spatial resolution
            self.dec1 = nn.ConvTranspose2d(in_channels=32, out_channels=64, kernel_size=(1, 1), stride=(1, 1))
            self.batch1 = nn.BatchNorm2d(64)
            self.dec2 = nn.ConvTranspose2d(in_channels=64, out_channels=128, kernel_size=(2, 2), stride=(2, 2), output_padding=1)
            self.batch2 = nn.BatchNorm2d(128)
            self.dec3 = nn.ConvTranspose2d(in_channels=128, out_channels=256, kernel_size=(3, 3), stride=(2, 2), output_padding=(1,0))
            self.batch3 = nn.BatchNorm2d(256)
            self.dec4 = nn.ConvTranspose2d(in_channels=256, out_channels=512, kernel_size=(3, 3), stride=(2, 2))
            self.batch4 = nn.BatchNorm2d(512)
            self.dec5 = nn.ConvTranspose2d(in_channels=512, out_channels=1, kernel_size=(3, 3), stride=(2, 2), output_padding=1)
    
        def forward(self, x):
            # map the latent code to a flat feature vector
            x = self.decoder_lin(x)
            # unflatten (N, features) back to (N, C, H, W)
            x = self.unflatten(x)
            # transposed convolutional layers to gradually increase spatial resolution
            x = F.relu(self.batch1(self.dec1(x)))
            x = F.relu(self.batch2(self.dec2(x)))
            x = F.relu(self.batch3(self.dec3(x)))
            x = F.relu(self.batch4(self.dec4(x)))
            x = self.dec5(x)
            # apply sigmoid activation function to ensure output is between 0 and 1
            x = torch.sigmoid(x)
            return x
    

    And the resulting output shapes are:

    ----------------------------------------------------------------
            Layer (type)               Output Shape         Param #
    ================================================================
                Linear-1                  [-1, 128]           2,688
                  ReLU-2                  [-1, 128]               0
                Linear-3                 [-1, 5760]         743,040
                  ReLU-4                 [-1, 5760]               0
             Unflatten-5            [-1, 32, 5, 36]               0
       ConvTranspose2d-6            [-1, 64, 5, 36]           2,112
           BatchNorm2d-7            [-1, 64, 5, 36]             128
       ConvTranspose2d-8          [-1, 128, 11, 73]          32,896
           BatchNorm2d-9          [-1, 128, 11, 73]             256
      ConvTranspose2d-10         [-1, 256, 24, 147]         295,168
          BatchNorm2d-11         [-1, 256, 24, 147]             512
      ConvTranspose2d-12         [-1, 512, 49, 295]       1,180,160
          BatchNorm2d-13         [-1, 512, 49, 295]           1,024
      ConvTranspose2d-14          [-1, 1, 100, 592]           4,609
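
    Plugging the new output_padding values into the transposed-convolution formula out = (in - 1)*stride - 2*padding + kernel + output_padding confirms the shapes above (same kind of sketch as before; note dec3 uses a different output_padding per dimension):

    def convT_out(size, kernel, stride, padding=0, output_padding=0):
        # ConvTranspose2d output size: (size - 1)*stride - 2*padding + kernel + output_padding
        return (size - 1) * stride - 2 * padding + kernel + output_padding

    h, w = 5, 36
    # (kernel, stride, output_padding along H, output_padding along W) for dec1..dec5
    for kernel, stride, op_h, op_w in [(1, 1, 0, 0), (2, 2, 1, 1), (3, 2, 1, 0), (3, 2, 0, 0), (3, 2, 1, 1)]:
        h = convT_out(h, kernel, stride, output_padding=op_h)
        w = convT_out(w, kernel, stride, output_padding=op_w)
        print(h, w)  # 5 36 -> 11 73 -> 24 147 -> 49 295 -> 100 592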