Tags: deep-learning, pytorch, autoencoder, pytorch-lightning

Pretrained lightning-bolts VAE not doing proper inference on training dataset


I'm using the CIFAR-10 pre-trained VAE from lightning-bolts. It should be able to reconstruct images with the quality shown in the picture below, taken from the docs (LHS are the real images, RHS are the reconstructions):

[Image from the docs: real images (left) and their reconstructions (right)]

However, when I write a simple script that loads the model, the weights, and tests it over the training set, I get a much worse reconstruction (top row are real images, bottom row are the generated ones):

[Image: my results, real images (top row) and reconstructions (bottom row)]

Here is a link to a self-contained colab notebook that reproduces the steps I've followed to produce the pictures.
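
In outline, the inference part of the notebook boils down to roughly the following (a simplified sketch, not the exact notebook code; it assumes the standard pl_bolts VAE and CIFAR10DataModule APIs):

    from pl_bolts.datamodules import CIFAR10DataModule
    from pl_bolts.models.autoencoders import VAE

    # load the pretrained CIFAR-10 VAE
    vae = VAE(input_height=32).from_pretrained("cifar10-resnet18")
    vae.eval()

    # get a dataloader over the training set
    dm = CIFAR10DataModule(".")
    dm.prepare_data()
    dm.setup("fit")

    # reconstruct one training batch
    X, _ = next(iter(dm.train_dataloader()))
    X_hat = vae(X)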

Am I doing something wrong in my inference process? Could it be that the weights are not as "good" as the docs claim?

Thanks!


Solution

  • First, the image from the docs you show is for the AE, not the VAE. The results for the VAE look much worse:
    [VAE reconstructions from the docs: https://pl-bolts-weights.s3.us-east-2.amazonaws.com/vae/vae-cifar10/vae_output.png]
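
    If reconstructions of the docs-image quality are what you are after, the deterministic AE can be loaded in the same way (a minimal sketch, assuming the pl_bolts AE exposes the same from_pretrained interface):

    from pl_bolts.models.autoencoders import AE

    # deterministic autoencoder; this is the model behind the sharper docs image
    ae = AE(input_height=32)
    ae = ae.from_pretrained("cifar10-resnet18")
    ae.eval()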

    Second, the docs state "Both input and generated images are normalized versions as the training was done with such images." So when you load the data you should specify normalize=True. When you plot your data, you will need to 'unnormalize' the data as well:

    from pl_bolts.datamodules import CIFAR10DataModule
    from pl_bolts.models.autoencoders import VAE
    from pytorch_lightning import Trainer
    import matplotlib.pyplot as plt
    import numpy as np
    import torch
    from torchvision import transforms
    
    torch.manual_seed(17)
    np.random.seed(17)
    
    # instantiate the VAE and load the pretrained CIFAR-10 weights
    vae = VAE(32, lr=0.00001)
    vae = vae.from_pretrained("cifar10-resnet18")
    
    # normalize=True adds CIFAR-10 normalization to the default transforms
    dm = CIFAR10DataModule(".", normalize=True)
    dm.prepare_data()
    dm.setup("fit")
    dataloader = dm.train_dataloader()
    
    # the second transform in default_transforms() is the Normalize transform;
    # build its inverse so images can be displayed: x_unnorm = x * std + mean
    print(dm.default_transforms())
    mean = torch.tensor(dm.default_transforms().transforms[1].mean)
    std = torch.tensor(dm.default_transforms().transforms[1].std)
    unnormalize = transforms.Normalize((-mean / std).tolist(), (1.0 / std).tolist())
    
    # run one training batch through the VAE in eval mode (no gradients needed)
    X, _ = next(iter(dataloader))
    vae.eval()
    with torch.no_grad():
        X_hat = vae(X)
    
    # top row: real (un-normalized) images, bottom row: reconstructions
    fig, axes = plt.subplots(2, 10, figsize=(10, 2))
    for i in range(10):
      ax_real = axes[0][i]
      ax_real.imshow(np.transpose(unnormalize(X[i]).numpy(), (1, 2, 0)))
      ax_real.get_xaxis().set_visible(False)
      ax_real.get_yaxis().set_visible(False)
    
      ax_gen = axes[1][i]
      ax_gen.imshow(np.transpose(unnormalize(X_hat[i]).numpy(), (1, 2, 0)))
      ax_gen.get_xaxis().set_visible(False)
      ax_gen.get_yaxis().set_visible(False)
    
    plt.show()

    Which gives something like this: [Image: VAE reconstructions with inputs and outputs un-normalized for display]

    Without normalization it looks like this: [Image: the same comparison without normalization]
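
    A side note on the un-normalization trick: transforms.Normalize(m, s) computes (x - m) / s, so Normalize(-mean / std, 1 / std) applied to a normalized image x' = (x - mean) / std gives (x' + mean / std) * std = x' * std + mean, i.e. it exactly undoes the original normalization. A quick sanity check (a sketch; the CIFAR-10 statistics below are approximate):

    import torch
    from torchvision import transforms

    mean = torch.tensor([0.491, 0.482, 0.447])  # approximate CIFAR-10 channel means
    std = torch.tensor([0.247, 0.243, 0.262])   # approximate CIFAR-10 channel stds

    normalize = transforms.Normalize(mean.tolist(), std.tolist())
    unnormalize = transforms.Normalize((-mean / std).tolist(), (1.0 / std).tolist())

    x = torch.rand(3, 32, 32)  # dummy image in [0, 1]
    assert torch.allclose(unnormalize(normalize(x)), x, atol=1e-5)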