Tags: python, tensorflow, deep-learning, pytorch, autoencoder

What could cause a VAE (Variational AutoEncoder) to output random noise even after training?


I have trained a VAE on the CIFAR10 dataset. However, when I try to generate images from the VAE, all I get back is gray noise. The implementation of this VAE follows the one from the book Generative Deep Learning, but the code uses PyTorch instead of TensorFlow.

The notebook containing the training as well as the generation can be found here, while the actual implementation of the VAE can be found here.

I have tried:

  1. Disabling dropouts.
  2. Increasing the dimension of the latent space.

Neither of these changes shows any improvement.

I have verified that:

  1. The input size matches the output size.
  2. Back-propagation runs successfully as the loss decreases during training.

Solution

  • Thanks for providing the code and a link to a Colab notebook! +1! Also, your code is well-written and easy to read. Unless I missed something, I think there are two problems with your code:

    1. The data normalization
    2. The implementation of the VAE loss.

    About 1., your CIFAR10DataModule class normalizes the RGB channels of the CIFAR10 images using mean = 0.5 and std = 0.5. Since the pixel values are initially in the [0,1] range, the normalized images have pixel values in the [-1,1] range. However, your Decoder class applies a nn.Sigmoid() activation to the reconstructed images, so the reconstructed images have pixel values in the [0,1] range. I suggest removing this mean-std normalization so that both the "true" images and the reconstructed images have their pixel values in the [0,1] range.
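    For illustration, a minimal sketch of a transform pipeline without the mean-std normalization could look like this (the exact CIFAR10DataModule internals are in your notebook, so treat the names below as assumptions):

    from torchvision import transforms
    from torchvision.datasets import CIFAR10

    # ToTensor() already scales pixel values to the [0, 1] range,
    # which matches the [0, 1] range produced by the decoder's nn.Sigmoid()
    transform = transforms.Compose([transforms.ToTensor()])

    train_set = CIFAR10(root="data", train=True, download=True, transform=transform)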

    About 2.: since you're dealing with RGB images, the MSE loss makes sense. The idea behind the MSE loss is the "Gaussian decoder": this decoder assumes the pixel values of a "true" image are generated by independent Gaussian distributions whose means are the pixel values of the reconstructed image (i.e. the output of the decoder), with a given variance. Your implementation of the reconstruction loss (namely r_loss = F.mse_loss(predictions, targets)) is equivalent to using a fixed variance. Using ideas from this paper, we can do better and obtain an analytic expression for the "optimal value" of this variance parameter. Finally, the reconstruction loss should be summed over all pixels (reduction = 'sum'). To understand why, have a look at the analytic expression of the reconstruction loss (see, for instance, this blog post, which considers the BCE loss).
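    To spell out the reasoning (a sketch, following the paper linked above): for a Gaussian decoder with mean x_hat and a single scalar standard deviation sigma, the negative log-likelihood of one pixel x is

        0.5 * ((x - x_hat) / sigma)^2 + log(sigma) + 0.5 * log(2 * pi)

    Summing this over all pixels and minimizing with respect to sigma gives sigma_opt^2 = MSE(x, x_hat). This is exactly what calculate_loss below computes via log_sigma_opt = 0.5 * mse.log(), and summing the per-pixel terms corresponds to the reduction = 'sum' reconstruction loss.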

    Here is what the refactored LitVAE class looks like:

    import torch
    import torch.nn.functional as F
    import pytorch_lightning as pl
    from torch import optim

    # VariationalAutoEncoder, MODEL_PARAMS and _tensor_size_3_t are defined in the
    # original notebook / repository and are assumed to be importable here
    class LitVAE(pl.LightningModule):
        def __init__(self,
                     learning_rate: float = 0.0005,
                     **kwargs) -> None:
            """
            Parameters
            ----------
            - `learning_rate: float`:
                learning rate for the optimizer
            - `**kwargs`:
                arguments to pass to the variational autoencoder constructor
            """
            super(LitVAE, self).__init__()
            
            self.learning_rate = learning_rate 
    
            self.vae = VariationalAutoEncoder(**kwargs)
    
        def forward(self, x) -> _tensor_size_3_t: 
            return self.vae(x)
    
        def training_step(self, batch, batch_idx):
            r_loss, kl_loss, sigma_opt = self.shared_step(batch)
            loss = r_loss + kl_loss
            
            self.log("train_loss_step", loss)
            return {"loss": loss, 'log':{"r_loss": r_loss / len(batch[0]), "kl_loss": kl_loss / len(batch[0]), 'sigma_opt': sigma_opt}}
    
        def training_epoch_end(self, outputs) -> None:
            # add computation graph
            if self.current_epoch == 0:
                sample_input = torch.randn((1, 3, 32, 32))
                sample_model = LitVAE(**MODEL_PARAMS)
                
                self.logger.experiment.add_graph(sample_model, sample_input)
                
            epoch_loss = self.average_metric(outputs, "loss")
            self.logger.experiment.add_scalar("train_loss_epoch", epoch_loss, self.current_epoch)
    
        def validation_step(self, batch, batch_idx):
            r_loss, kl_loss, _ = self.shared_step(batch)
            loss = r_loss + kl_loss
    
            self.log("valid_loss_step", loss)
    
            return {"loss": loss}
    
        def validation_epoch_end(self, outputs) -> None:
            epoch_loss = self.average_metric(outputs, "loss")
            self.logger.experiment.add_scalar("valid_loss_epoch", epoch_loss, self.current_epoch)
    
        def test_step(self, batch, batch_idx):
            r_loss, kl_loss, _ = self.shared_step(batch)
            loss = r_loss + kl_loss
            
            self.log("test_loss_step", loss)
            return {"loss": loss}
    
        def test_epoch_end(self, outputs) -> None:
            epoch_loss = self.average_metric(outputs, "loss")
            self.logger.experiment.add_scalar("test_loss_epoch", epoch_loss, self.current_epoch)
    
        def configure_optimizers(self):
            return optim.Adam(self.parameters(), lr=self.learning_rate)
            
        def shared_step(self, batch) -> torch.TensorType: 
            # images are both samples and targets thus original 
            # labels from the dataset are not required
            true_images, _ = batch
    
            # perform a forward pass through the VAE 
            # mean and log_variance are used to calculate the KL Divergence loss 
            # decoder_output represents the generated images 
            mean, log_variance, generated_images = self(true_images)
    
            r_loss, kl_loss, sigma_opt = self.calculate_loss(mean, log_variance, generated_images, true_images)
            return r_loss, kl_loss, sigma_opt
    
        def calculate_loss(self, mean, log_variance, predictions, targets):
            # analytic optimum of the Gaussian decoder's variance: sigma_opt^2 = MSE
            mse = F.mse_loss(predictions, targets, reduction='mean')
            log_sigma_opt = 0.5 * mse.log()
            # Gaussian negative log-likelihood with sigma = sigma_opt, summed over all pixels
            r_loss = 0.5 * torch.pow((targets - predictions) / log_sigma_opt.exp(), 2) + log_sigma_opt
            r_loss = r_loss.sum()
            kl_loss = self._compute_kl_loss(mean, log_variance)
            return r_loss, kl_loss, log_sigma_opt.exp()
    
        def _compute_kl_loss(self, mean, log_variance): 
            # closed-form KL divergence between N(mean, exp(log_variance)) and the
            # standard normal prior N(0, I), summed over latent dimensions and batch
            return -0.5 * torch.sum(1 + log_variance - mean.pow(2) - log_variance.exp())
    
        def average_metric(self, metrics, metric_name):
            avg_metric = torch.stack([x[metric_name] for x in metrics]).mean()
            return avg_metric
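
    For completeness, training the refactored module could look roughly like this (a sketch; MODEL_PARAMS and CIFAR10DataModule are the names from your notebook and are assumed to be defined as before):

    import pytorch_lightning as pl

    model = LitVAE(**MODEL_PARAMS)
    datamodule = CIFAR10DataModule()

    trainer = pl.Trainer(max_epochs=10)
    trainer.fit(model, datamodule=datamodule)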
    

    After 10 epochs, this is what the reconstructed images look like:

    [image: reconstructed CIFAR10 images after 10 epochs of training]