
Why is my autoencoder not learning the FMNIST dataset?


I am using a simple autoencoder to learn images from the FashionMNIST dataset. I have preprocessed the dataset by converting it to grayscale and normalizing it to [0, 1]. I kept the network shallow so that it cannot simply learn an identity mapping.

Here's my PyTorch code -

import torch
import torchvision as tv
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from torch import nn
import os
from torchviz import make_dot
transforms = tv.transforms.Compose([tv.transforms.Grayscale(num_output_channels=1)])
trainset = tv.datasets.FashionMNIST(root='./data', train=True,
                                        download=True, transform=transforms)
PATH = './ae.pth'
data = trainset.data.float()
data = data/255
# print(trainset.data.shape)
plt.imshow(trainset.data[0], cmap = 'gray')
plt.show()



class NeuralNetwork(nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.flatten = nn.Flatten()
        self.encode = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 30),
            nn.ReLU()
        )
        self.decode = nn.Sequential(
            nn.Linear(30, 512),
            nn.ReLU(),
            nn.Linear(512, 28*28),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.flatten(x)
        encoded = self.encode(x)
        decoded = self.decode(encoded)
        return decoded
if(os.path.exists(PATH)):
    print("Loading data on cpu")
    device = torch.device('cpu')
    model = NeuralNetwork()
    model.load_state_dict(torch.load(PATH, map_location=device))

else:
    device = "cuda" if torch.cuda.is_available() else "cpu"
    data = data.to(device)
    print(f"Using device = {device}")
    model = NeuralNetwork().to(device)
    # print(model)

    lossFn  = nn.BCELoss()

    optimizer = torch.optim.SGD(model.parameters(), lr = 1e-3)

    for epoch in range(1000):
        print("Epoch = ", epoch)
        optimizer.zero_grad()
        outputs = model(data)
        loss = lossFn(outputs, data.reshape(-1, 784))
        loss.backward()
        optimizer.step()

    torch.save(model.state_dict(), PATH)
    data = data.to("cpu")
    model = model.to("cpu")

pred = model(data)
pred = pred.reshape(-1, 28, 28)
# print(pred.shape)
plt.imshow(pred.detach().numpy()[0], cmap = 'gray')
plt.show()

For testing, I am inputting the following image:

[input image]

However, I get this as output:

[output image]


Solution

  • I had an intuition that there was an issue with your loss function. When working with images, distance-based losses such as L1 or L2 work really well, since you are essentially measuring how far away your predictions are from the ground-truth images. That matched what I observed here: with BCE the loss wasn't converging; it was oscillating.
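
    To make the comparison concrete, here is a minimal standalone sketch (random tensors standing in for images; not part of the original code) showing both losses applied to the same reconstruction, with MSE directly penalizing pixel distance:

    import torch
    from torch import nn

    target = torch.rand(4, 784)        # fake batch of flattened "images" in [0, 1]
    recon = torch.rand(4, 784)         # fake reconstructions in the same range

    mse = nn.MSELoss()(recon, target)  # mean squared pixel distance
    bce = nn.BCELoss()(recon, target)  # per-pixel binary cross-entropy
    print(mse.item(), bce.item())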

    I rewrote the entire thing, replacing the BCE loss with MSE loss. In just 50 epochs the loss has gone down considerably, and it is still going down. Here is the prediction after just 50 epochs:

    [prediction after 50 epochs]

    The ground-truth image is -

    [ground-truth image]

    I believe that you can get the loss down much more if you train for longer.

    Here is the full code. I used a DataLoader to batch and shuffle the data.

    I also changed the transformations so that the resulting data is a torch tensor.
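
    One detail worth spelling out: ToTensor also scales the uint8 pixels into [0, 1], which matches the Sigmoid output range of the decoder, so MSE compares values on the same scale. A quick standalone sanity check (a sketch, separate from the script below):

    import torchvision as tv

    t = tv.transforms.Compose([
        tv.transforms.Grayscale(num_output_channels=1),
        tv.transforms.ToTensor()
    ])
    ds = tv.datasets.FashionMNIST(root='./data', train=True, download=True, transform=t)
    img, _ = ds[0]
    print(img.shape, img.min().item(), img.max().item())
    # e.g. torch.Size([1, 28, 28]) with values inside [0, 1]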

    import torch
    import torchvision as tv
    import matplotlib.pyplot as plt
    from torch import nn
    from torch.utils.data import DataLoader

    # pick a device once; it is used below to place the model and each batch
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # build the transform pipeline under a non-shadowing name; ToTensor converts
    # PIL images to float tensors in [0, 1]
    transform = tv.transforms.Compose([
        tv.transforms.Grayscale(num_output_channels=1),
        tv.transforms.ToTensor()
    ])
    
    trainset = tv.datasets.FashionMNIST(root='./data', train=True,
                                        download=True, transform=transform)
    
    loader = DataLoader(trainset, batch_size=32, num_workers=1, shuffle=True)
    
    class NeuralNetwork(nn.Module):
        def __init__(self):
            super(NeuralNetwork, self).__init__()
            self.flatten = nn.Flatten()
            self.encode = nn.Sequential(
                nn.Linear(28*28, 512),
                nn.ReLU(),
                nn.Linear(512, 30),
                nn.ReLU()
            )
            self.decode = nn.Sequential(
                nn.Linear(30, 512),
                nn.ReLU(),
                nn.Linear(512, 28*28),
                nn.Sigmoid()
            )
    
        def forward(self, x):
            x = self.flatten(x)
            encoded = self.encode(x)
            decoded = self.decode(encoded)
            return decoded
    
    model = NeuralNetwork().to(device)
    lossFn  = nn.MSELoss()
    
    optimizer = torch.optim.SGD(model.parameters(), lr = 1e-2)
    
    epochs = 50
    for epoch in range(epochs):
        for images, _ in loader:       # labels are not needed for an autoencoder
            images = images.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            # the reconstruction target is the flattened input image itself
            loss = lossFn(outputs, images.reshape(-1, 28*28))
            loss.backward()
            optimizer.step()

        print(f'Epochs done : {epoch}')
        print(f'Loss : {loss.item()}')
    

    Here is some inference code -

    # infer on some test data
    testset = tv.datasets.FashionMNIST(root='./data', train=False,
                                            download=False, transform=transform)
    
    testloader = DataLoader(testset, shuffle=False, batch_size=32, num_workers=1)
    
    test_images, test_labels = next(iter(testloader))
    test_images = test_images.to(device)
    predictions = model(test_images)
    
    prediction = predictions[0]
    prediction = prediction.view(1, 28, 28)
    
    prediction = prediction.detach().cpu().numpy()
    
    prediction = prediction.transpose(1, 2, 0)
    
    # plot the prediction
    plt.imshow(prediction, cmap = 'gray')
    plt.show()
    
    # plot the actual image 
    test_image = test_images[0]
    test_image = test_image.detach().cpu().numpy()
    test_image = test_image.transpose(1, 2, 0)
    
    plt.imshow(test_image, cmap='gray')
    plt.show()
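
    One small addition worth making here (standard PyTorch practice, not in the original answer): put the model in eval mode and disable gradient tracking for the forward pass, so no autograd state is built up during inference:

    model.eval()               # a no-op for this particular model, but a good habit
    with torch.no_grad():      # skip gradient bookkeeping during the forward pass
        predictions = model(test_images)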
    

    This is the loss going down -

    Epochs done : 39
    Loss : 0.04641226679086685
    Epochs done : 40
    Loss : 0.04445071145892143
    Epochs done : 41
    Loss : 0.05033266171813011
    Epochs done : 42
    Loss : 0.04813298210501671
    Epochs done : 43
    Loss : 0.0474831722676754
    Epochs done : 44
    Loss : 0.044186390936374664
    Epochs done : 45
    Loss : 0.049083154648542404
    Epochs done : 46
    Loss : 0.04645842686295509
    Epochs done : 47
    Loss : 0.04586248844861984
    Epochs done : 48
    Loss : 0.0467853844165802
    Epochs done : 49
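
    Note that the printed value is only the loss of the last mini-batch in each epoch, which is why the numbers above wobble even while training improves overall. A small tweak (a sketch reusing loader, model, lossFn, optimizer, epochs, and device from the script above) is to report the average over all batches:

    for epoch in range(epochs):
        total, count = 0.0, 0
        for images, _ in loader:
            images = images.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = lossFn(outputs, images.reshape(-1, 28*28))
            loss.backward()
            optimizer.step()
            total += loss.item() * images.size(0)   # weight each batch by its size
            count += images.size(0)
        print(f'Epochs done : {epoch}')
        print(f'Mean loss : {total / count:.5f}')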