neural-network pytorch data-visualization autoencoder

Visualize Autoencoder Output


I have a pretty noob question, but I'm stuck... I have created an autoencoder with PyTorch and trained it on the typical MNIST dataset:

import torch
from torch import nn, optim

class Autoencoder(nn.Module):
    def __init__(self, **kwargs):
        super().__init__()
        self.encoder_hidden_layer = nn.Linear(
            in_features=kwargs["input_shape"], out_features=kwargs["embedding_dim"]
        )
        self.encoder_output_layer = nn.Linear(
            in_features=kwargs["embedding_dim"], out_features=kwargs["embedding_dim"]
        )
        self.decoder_hidden_layer = nn.Linear(
            in_features=kwargs["embedding_dim"], out_features=kwargs["embedding_dim"]
        )
        self.decoder_output_layer = nn.Linear(
            in_features=kwargs["embedding_dim"], out_features=kwargs["input_shape"]
        )

    def forward(self, features):
        activation = self.encoder_hidden_layer(features)
        activation = torch.relu(activation)

        code = self.encoder_output_layer(activation)
        code = torch.relu(code)
        
        activation = self.decoder_hidden_layer(code)
        activation = torch.relu(activation)

        activation = self.decoder_output_layer(activation)
        reconstructed = torch.relu(activation)

        return reconstructed

model = Autoencoder(input_shape=784, embedding_dim=128)

criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.0001) 

What I want now is to visualize the reconstructed images, but I don't know how to do it. I know it's quite simple, but I cannot find a way. I know that the shape of the output is [128, 784], because the batch_size is 128 and 784 is 28x28 (x1 channel).

Could anyone please tell me how I could get an image from my reconstructed tensor?

Thank you so much!


Solution

  • First you will have to reshape the output tensor from 128x784 into 128x1x28x28 (batch, channel, height, width). Here x is the tensor returned by the model:

    reconstructed = x.reshape(128, 1, 28, 28)
    

    Then, you can convert one of the batch elements into a PIL image using torchvision's functions. The following will show the first image:

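    import torchvision.transforms as T
    # ToPILImage expects float values in [0, 1]; the final ReLU can exceed 1, so clamp first
    img = T.ToPILImage()(reconstructed[0].detach().cpu().clamp(0, 1))
    img.show()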