Pretrained lightning-bolts VAE not doing proper inference on training dataset

I'm using the CIFAR-10 pretrained VAE from lightning-bolts. It should be able to reconstruct images with the quality shown in this picture taken from the docs (left: real images; right: generated images):

However, when I write a simple script that loads the model and its weights and tests it on the training set, I get much worse reconstructions (top row: real images; bottom row: generated images):

Here is a link to a self-contained Colab notebook that reproduces the steps I followed to produce the pictures.

Am I doing something wrong in my inference process? Or could it be that the weights are not as good as the docs claim?



First, the image from the docs that you show is for the AE, not the VAE. The VAE results shown in the docs look much worse.

Second, the docs state: "Both input and generated images are normalized versions as the training was done with such images." So when you load the data you should specify `normalize=True`, and when you plot the data you need to unnormalize it as well. It also helps to put the model in evaluation mode before running inference:

    from pl_bolts.datamodules import CIFAR10DataModule
    from pl_bolts.models.autoencoders import VAE
    import matplotlib.pyplot as plt
    import torch
    from torchvision import transforms

    # Load the pretrained CIFAR-10 VAE and switch to inference mode
    # (BatchNorm layers then use their running statistics).
    vae = VAE(input_height=32).from_pretrained("cifar10-resnet18")
    vae.eval()

    # Load CIFAR-10 with the same normalization that was used for training.
    dm = CIFAR10DataModule(".", normalize=True)
    dm.prepare_data()
    dm.setup("fit")
    dataloader = dm.train_dataloader()

    # Build the inverse of the Normalize transform, used only for plotting.
    norm = dm.default_transforms().transforms[1]
    mean = torch.tensor(norm.mean)
    std = torch.tensor(norm.std)
    unnormalize = transforms.Normalize((-mean / std).tolist(), (1.0 / std).tolist())

    X, _ = next(iter(dataloader))
    with torch.no_grad():
        X_hat = vae(X)

    fig, axes = plt.subplots(2, 10, figsize=(10, 2))
    for i in range(10):
        # Top row: real images; bottom row: reconstructions.
        # Clip to [0, 1] and convert CHW -> HWC for imshow.
        axes[0][i].imshow(unnormalize(X[i]).clamp(0, 1).permute(1, 2, 0).numpy())
        axes[1][i].imshow(unnormalize(X_hat[i]).clamp(0, 1).permute(1, 2, 0).numpy())
        axes[0][i].axis("off")
        axes[1][i].axis("off")
    plt.show()
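The `unnormalize` transform above works because `Normalize` computes `(x - mean) / std` per channel. Passing `mean' = -mean / std` and `std' = 1 / std` therefore gives `(x + mean / std) * std = x * std + mean`, which undoes the original normalization. Here is a minimal pure-Python check of that identity. The per-channel statistics and the pixel are illustrative values I made up (assumptions, not necessarily what the datamodule uses; read the real ones from `dm.default_transforms()` as in the code above):

```python
# Channel-wise normalization as torchvision's Normalize applies it: (x - mean) / std.
def normalize(pixel, mean, std):
    return [(x - m) / s for x, m, s in zip(pixel, mean, std)]


def inverse_params(mean, std):
    # Parameters that make Normalize act as its own inverse:
    # (y - (-m / s)) / (1 / s) == y * s + m
    inv_mean = [-m / s for m, s in zip(mean, std)]
    inv_std = [1.0 / s for s in std]
    return inv_mean, inv_std


# Illustrative per-channel statistics (assumed values, RGB order).
MEAN = [0.49, 0.48, 0.45]
STD = [0.25, 0.24, 0.26]

pixel = [0.2, 0.5, 0.9]                    # an arbitrary RGB pixel in [0, 1]
normalized = normalize(pixel, MEAN, STD)   # what the model sees
inv_mean, inv_std = inverse_params(MEAN, STD)
restored = normalize(normalized, inv_mean, inv_std)  # "unnormalize"

print(normalized)
print(restored)  # equal to `pixel` up to floating-point rounding
```

Wrapping `inv_mean` and `inv_std` in `transforms.Normalize` gives exactly the `unnormalize` used above. Plain lists are used here only so the identity can be checked without PyTorch.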

