forward() using Pytorch Lightning not giving consistent binary classification results for single VS multiple images


I have trained a Variational Autoencoder (VAE) with an additional fully connected layer after the encoder for binary image classification. It is set up using PyTorch Lightning. The encoder and decoder are the resnet18 components from the PyTorch Lightning Bolts repo.

from pl_bolts.models.autoencoders.components import (
    resnet18_encoder,
    resnet18_decoder
)

class VariationalAutoencoder(LightningModule):

...

    self.first_conv: bool = False
    self.maxpool1: bool = False
    self.enc_out_dim: int = 512
    self.encoder = resnet18_encoder(first_conv, maxpool1)
    self.fc_object_identity = nn.Linear(self.enc_out_dim, 1)


    def forward(self, x):
        x_encoded = self.encoder(x)
        mu = self.fc_mu(x_encoded)
        log_var = self.fc_var(x_encoded)
        p, q, z = self.sample(mu, log_var)

        x_classification_score = torch.sigmoid(self.fc_object_identity(x_encoded))

        return self.decoder(z), x_classification_score

variational_autoencoder = VariationalAutoencoder.load_from_checkpoint(
        checkpoint_path=str(checkpoint_file_path)
    )

with torch.no_grad():
    predicted_images, classification_score = variational_autoencoder(test_images)

The reconstructions work well for single images and multiple images when passed through forward(). However, when I pass multiple images to forward() I get different results for the classification score than if I pass a single image tensor:

# Image 1 (class=1) [1, 3, 64, 64]
x_classification_score = 0.9857

# Image 2 (class=0) [1, 3, 64, 64]
x_classification_score = 0.0175

# Image 1 and 2 [2, 3, 64, 64]
x_classification_score =[[0.8943],
                         [0.1736]]

Why is this happening?


Solution

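  • You are using resnet18, which contains torch.nn.BatchNorm2d layers.

    The behavior of these layers depends on whether the module is in train or eval mode. In train mode, mean and variance are computed across the current batch, so the output for each image depends on the other examples in that batch. That is why a single image and the same image inside a batch of two get different scores.

    In eval mode, the running mean and variance accumulated during training (via a moving average) are used instead. These do not depend on the batch, so single-image and batched results agree. Note that torch.no_grad() only disables gradient tracking and does not change the mode, so if the model is still in train mode after loading, call variational_autoencoder.eval() before running inference.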
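
The effect can be reproduced without the original checkpoint. The following is a minimal sketch, not the asker's model: `net` is a hypothetical stand-in for the encoder plus classification head (a single conv layer, BatchNorm2d, pooling and a linear layer), and the two input images are random tensors made deliberately different in scale so that batch statistics differ from per-image statistics.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for the resnet18 encoder + classification head.
# The only property that matters here is that it contains BatchNorm2d.
net = nn.Sequential(
    nn.Conv2d(3, 4, kernel_size=3, padding=1),
    nn.BatchNorm2d(4),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(4, 1),
    nn.Sigmoid(),
)

# Two deliberately different random "images", shape [2, 3, 64, 64].
images = torch.stack([
    torch.randn(3, 64, 64),
    torch.randn(3, 64, 64) * 3 + 1,
])


def scores(model, batch):
    # no_grad() disables gradient tracking only; it does not change train/eval mode.
    with torch.no_grad():
        return model(batch)


# Train mode: BatchNorm normalizes with statistics of the current batch.
net.train()
train_single = torch.cat([scores(net, images[i:i + 1]) for i in range(2)])
train_batched = scores(net, images)

# Eval mode: BatchNorm normalizes with its stored running statistics.
net.eval()
eval_single = torch.cat([scores(net, images[i:i + 1]) for i in range(2)])
eval_batched = scores(net, images)

train_mismatch = not torch.allclose(train_single, train_batched, atol=1e-5)
eval_match = torch.allclose(eval_single, eval_batched, atol=1e-5)

print("train mode, one-by-one vs batched differ:", train_mismatch)
print("eval mode, one-by-one vs batched agree:", eval_match)
```

Applied to the code in the question, this would mean calling variational_autoencoder.eval() once after load_from_checkpoint(...) and before the torch.no_grad() block.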