I am using a simple autoencoder to reconstruct images from the FashionMNIST dataset. I preprocessed the dataset by converting it to grayscale and normalizing it. I deliberately kept the network shallow to prevent it from simply learning a direct (identity) mapping.
Here's my PyTorch code:
import torch
import torchvision as tv
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from torch import nn
import os
from torchviz import make_dot
# FashionMNIST images are already single-channel, so Grayscale changes nothing;
# it is kept to mirror the described preprocessing. The Compose is named
# `transform` (not `transforms`) so it does not shadow the imported
# `torchvision.transforms` module.
transform = tv.transforms.Compose([tv.transforms.Grayscale(num_output_channels=1)])
trainset = tv.datasets.FashionMNIST(root='./data', train=True,
                                    download=True, transform=transform)
PATH = './ae.pth'

# `trainset.data` is the raw pixel tensor, so the transform above is never
# applied to it; normalization to [0, 1] is done by hand here instead.
data = trainset.data.float() / 255

# Preview the first raw training image.
plt.imshow(trainset.data[0], cmap='gray')
plt.show()
class NeuralNetwork(nn.Module):
    """Fully connected autoencoder for 28x28 single-channel images.

    The encoder compresses each flattened image (784 values) into a
    30-dimensional code; the decoder expands the code back to 784 values
    squashed into [0, 1] by a final sigmoid.
    """

    def __init__(self):
        super().__init__()
        n_pixels = 28 * 28
        n_hidden = 512
        n_code = 30
        # Flatten every dimension after the batch dimension.
        self.flatten = nn.Flatten()
        # Attribute names and layer order fix the state_dict keys
        # (e.g. "encode.0.weight"); keep them stable so saved checkpoints load.
        self.encode = nn.Sequential(
            nn.Linear(n_pixels, n_hidden),
            nn.ReLU(),
            nn.Linear(n_hidden, n_code),
            nn.ReLU(),
        )
        self.decode = nn.Sequential(
            nn.Linear(n_code, n_hidden),
            nn.ReLU(),
            nn.Linear(n_hidden, n_pixels),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return the reconstruction of ``x`` as a (batch, 784) tensor."""
        flat = self.flatten(x)
        code = self.encode(flat)
        return self.decode(code)
if os.path.exists(PATH):
    # A checkpoint from an earlier run exists: load it on the CPU and skip
    # training. Delete the checkpoint file to retrain (e.g. after changing
    # the optimizer below).
    print("Loading data on cpu")
    device = torch.device('cpu')
    model = NeuralNetwork()
    model.load_state_dict(torch.load(PATH, map_location=device))
else:
    device = "cuda" if torch.cuda.is_available() else "cpu"
    data = data.to(device)
    print(f"Using device = {device}")
    model = NeuralNetwork().to(device)
    lossFn = nn.BCELoss()
    # Every epoch is a single full-batch step, so training gets only 1000
    # updates in total. Plain SGD at lr=1e-3 is likely far too small a step
    # for that budget, leaving the decoder close to its initialization.
    # Adam adapts per-parameter step sizes and makes much more progress in
    # the same number of steps.
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    # The reconstruction target never changes, so reshape it once.
    targets = data.reshape(-1, 784)
    for epoch in range(1000):
        optimizer.zero_grad()
        outputs = model(data)
        loss = lossFn(outputs, targets)
        loss.backward()
        optimizer.step()
        # Logging the loss shows whether training is actually converging.
        print("Epoch = ", epoch, " loss = ", loss.item())
    torch.save(model.state_dict(), PATH)

data = data.to("cpu")
model = model.to("cpu")
# Inference only: evaluation mode, no autograd graph, and only the image that
# is actually displayed instead of the whole training set.
model.eval()
with torch.no_grad():
    pred = model(data[:1])
pred = pred.reshape(-1, 28, 28)
plt.imshow(pred.numpy()[0], cmap='gray')
plt.show()
For testing, I input the following image: [image not preserved] However, I get this as output: [image not preserved]
I suspected an issue with your loss function. When working with images, distance-based losses such as L1 or L2 work really well, because they directly measure how far your predictions are from the ground-truth images. This matched what I observed: with BCE the loss wasn't converging; it was oscillating instead.
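I rewrote the entire script, replacing BCE loss with MSE loss, and in just 50 epochs the loss has gone down considerably and is still decreasing. Note that the rewrite also trains on mini-batches of 32 images (instead of one full-dataset step per epoch) and raises the SGD learning rate from 1e-3 to 1e-2, so it performs far more, and larger, parameter updates; these changes probably contribute substantially to the improvement as well. Here is the prediction after just 50 epochs: [image not preserved]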
The ground-truth image is: [image not preserved]
I believe you can bring the loss down much further by training for longer.
Here is the full code. I used a `DataLoader` to batch and load the data, and I added `ToTensor()` to the transforms so that each sample is returned as a torch tensor.
import torch
import torchvision as tv
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from torch import nn
from torch.utils.data import DataLoader
# Grayscale keeps a single channel (FashionMNIST is already grayscale, so it
# changes nothing); ToTensor turns each PIL image into a float tensor of shape
# (1, 28, 28) scaled to [0, 1].
# This statement rebinds the name `transforms` to the Compose object, so the
# transform classes are reached through `tv.transforms`; writing
# `transforms.Grayscale` would fail if the statement were executed a second
# time (e.g. when re-running a notebook cell).
transforms = tv.transforms.Compose([
    tv.transforms.Grayscale(num_output_channels=1),
    tv.transforms.ToTensor(),
])
trainset = tv.datasets.FashionMNIST(root='./data', train=True,
                                    download=True, transform=transforms)
# Shuffled mini-batches of 32 images, loaded by one worker process.
loader = DataLoader(trainset, batch_size=32, num_workers=1, shuffle=True)
class NeuralNetwork(nn.Module):
    """Shallow fully connected autoencoder for 28x28 single-channel images.

    Encoder: 784 -> 512 -> 30 (ReLU after each layer).
    Decoder: 30 -> 512 -> 784 (ReLU, then a sigmoid so outputs lie in [0, 1]).
    """

    def __init__(self):
        super().__init__()
        image_size = 28 * 28
        hidden_size = 512
        bottleneck_size = 30
        self.flatten = nn.Flatten()
        # The names `encode`/`decode` and the layer order define the
        # state_dict keys; keep them unchanged so checkpoints stay loadable.
        self.encode = nn.Sequential(
            nn.Linear(image_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, bottleneck_size),
            nn.ReLU(),
        )
        self.decode = nn.Sequential(
            nn.Linear(bottleneck_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, image_size),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Flatten ``x`` per sample, encode it and decode it to (batch, 784)."""
        return self.decode(self.encode(self.flatten(x)))
# Select the device here so this script does not depend on a `device`
# variable being defined somewhere else.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = NeuralNetwork().to(device)
lossFn = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr = 1e-2)
epochs = 50
for epoch in range(epochs):
    running_loss = 0.0
    for images, _labels in loader:
        # An autoencoder ignores the class labels, so only images go to the device.
        images = images.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = lossFn(outputs, images.reshape(-1, 28*28))
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    # Report the mean of all mini-batch losses in the epoch; the last batch
    # alone is a noisy estimate of training progress.
    print(f'Loss : {running_loss / len(loader)}')
    print(f'Epochs done : {epoch}')
Here is some inference code:
# infer on some test data
# download=True is a no-op when the files are already present and avoids a
# failure when they are not.
testset = tv.datasets.FashionMNIST(root='./data', train=False,
                                   download=True, transform=transforms)
testloader = DataLoader(testset, shuffle=False, batch_size=32, num_workers=1)
test_images, test_labels = next(iter(testloader))
test_images = test_images.to(device)

# Inference only: evaluation mode and no autograd graph.
model.eval()
with torch.no_grad():
    predictions = model(test_images)

# plot the prediction (each model output row holds 784 flattened pixels)
prediction = predictions[0].reshape(28, 28).cpu().numpy()
plt.imshow(prediction, cmap='gray')
plt.show()

# plot the actual image (drop the leading channel axis of the (1, 28, 28) tensor)
test_image = test_images[0].squeeze(0).cpu().numpy()
plt.imshow(test_image, cmap='gray')
plt.show()
Here is the loss going down (each printed value is the loss of the last mini-batch of that epoch):
Epochs done : 39
Loss : 0.04641226679086685
Epochs done : 40
Loss : 0.04445071145892143
Epochs done : 41
Loss : 0.05033266171813011
Epochs done : 42
Loss : 0.04813298210501671
Epochs done : 43
Loss : 0.0474831722676754
Epochs done : 44
Loss : 0.044186390936374664
Epochs done : 45
Loss : 0.049083154648542404
Epochs done : 46
Loss : 0.04645842686295509
Epochs done : 47
Loss : 0.04586248844861984
Epochs done : 48
Loss : 0.0467853844165802
Epochs done : 49