Tags: python, machine-learning, pytorch, resnet

ResNet inconsistency between train and eval mode


I'm trying to implement a ResNet in PyTorch, but I found that the output of the forward pass varies greatly between train and eval mode. Since train and eval mode don't affect anything besides batch norm and dropout, I don't know whether these results make sense.

Below is my test code:

import torch
from torch import nn
from torchvision import models

class resnet_lstm(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # ImageNet-pretrained ResNet-50 (torchvision >= 0.13; older versions use pretrained=True)
        resnet = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
        self.share = torch.nn.Sequential()
        self.share.add_module("conv1", resnet.conv1)
        self.share.add_module("bn1", resnet.bn1)  # BatchNorm2d carrying ImageNet running stats
        self.share.add_module("relu", resnet.relu)
        self.share.add_module("maxpool", resnet.maxpool)
        self.share.add_module("layer1", resnet.layer1)
        self.share.add_module("layer2", resnet.layer2)
        self.share.add_module("layer3", resnet.layer3)
        self.share.add_module("layer4", resnet.layer4)
        self.share.add_module("avgpool", resnet.avgpool)
        # Classification head; not used by forward() in this test
        self.fc = nn.Sequential(nn.Linear(2048, 512),
                                nn.ReLU(),
                                nn.Linear(512, 7))

    def forward(self, x):
        x = x.view(-1, 3, 224, 224)
        x = self.share(x)  # pooled features of shape (N, 2048, 1, 1)
        return x
    
model = resnet_lstm()

input_ = torch.randn(1, 3, 224, 224)
model.train()
print("train mode output", model(input_))
model.eval()
print("eval mode output", model(input_))

Terminal output:

train mode output tensor([[[[0.3603]],

         [[0.5518]],

         [[0.4599]],

         ...,

         [[0.3381]],

         [[0.4445]],

         [[0.3481]]]], grad_fn=<MeanBackward1>)
eval mode output tensor([[[[0.1582]],

         [[0.1822]],

         [[0.0000]],

         ...,

         [[0.0567]],

         [[0.0054]],

         [[0.3605]]]], grad_fn=<MeanBackward1>)

As you can see, the outputs in the two modes are very different from each other. Would this hurt the model's performance?
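One side effect in the test code above is worth knowing about: the train-mode call updates every BatchNorm layer's running_mean and running_var (no optimizer step is needed, and torch.no_grad() does not prevent it), so the eval-mode output that follows is computed with running statistics that have already been nudged toward the random input. The minimal sketch below uses a small toy network as an assumed stand-in for the ResNet and shows one way to avoid the effect: take a copy.deepcopy of the model before any train-mode pass and use that copy for the eval comparison.

```python
import copy

import torch
from torch import nn

torch.manual_seed(0)

# Toy network standing in for the ResNet (assumption: any BatchNorm model behaves alike).
net = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.BatchNorm2d(8), nn.ReLU())
x = torch.randn(1, 3, 32, 32)

# Snapshot taken BEFORE any train-mode pass, so its running statistics stay untouched.
pristine = copy.deepcopy(net).eval()

with torch.no_grad():  # no_grad does not stop BatchNorm from updating its buffers
    net.train()
    net(x)  # train-mode pass: normalizes with batch stats and updates running stats
    net.eval()
    out_after_train_pass = net(x)
    out_pristine = pristine(x)

bn = net[1]
print(bn.num_batches_tracked.item())                       # 1
print(torch.allclose(out_after_train_pass, out_pristine))  # False
```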


Solution

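  • This is caused by batchnorm, which behaves differently in train mode and eval mode. (torchvision's ResNet-50 has no dropout layers, so batchnorm is the only part of this model affected by train()/eval().)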

    In train mode, batchnorm computes the mean and variance of each batch that passes through it and folds them into a running mean and running variance: an exponential moving average over batches, controlled by the layer's momentum (0.1 by default in PyTorch).

    In train mode, batchnorm normalizes with the current batch stats.

    In eval mode, batchnorm normalizes with the running mean and running variance.

    Your model is based on a pretrained ImageNet model. This means that when the model is in eval mode, the batchnorm layers use the running statistics they accumulated while training on ImageNet.

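    When the model is in train mode, the batchnorm layers instead use in-batch statistics computed from the random input you pass to the model. With a batch size of 1, those statistics come from a single image (over its spatial positions only). That same forward pass also updates the running statistics, which slightly changes what the layers use afterwards in eval mode.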

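    Random noise produces activations whose per-layer mean and variance differ a lot from those produced by natural ImageNet images, so normalizing with the batch statistics gives very different results from normalizing with the stored ImageNet statistics. That is why you see such a large difference.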

    The difference itself is expected behavior. If you fine-tune this model on the dataset you plan to use and then compare train and eval outputs on a real image from that dataset, you should see a much smaller deviation between them.
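The two normalization rules described above can be reproduced by hand and compared against nn.BatchNorm2d. This is a minimal sketch on a single freshly initialized layer (running_mean starts at 0, running_var at 1) with a synthetic batch; the helper name normalize is hypothetical. Train mode uses the batch mean and biased batch variance, the same pass updates the running estimates with the momentum-weighted average (using the unbiased variance, as PyTorch documents), and eval mode then normalizes with those running estimates.

```python
import torch
from torch import nn

torch.manual_seed(0)

bn = nn.BatchNorm2d(4)  # defaults: momentum=0.1, eps=1e-5, running_mean=0, running_var=1
x = torch.randn(8, 4, 5, 5) * 3 + 2  # batch stats far from the initial running stats


def normalize(x, mean, var, bn):
    """Hypothetical helper: y = (x - mean) / sqrt(var + eps) * weight + bias, per channel."""
    shape = (1, -1, 1, 1)
    return ((x - mean.view(shape)) / torch.sqrt(var.view(shape) + bn.eps)
            * bn.weight.view(shape) + bn.bias.view(shape))


with torch.no_grad():
    # Train mode: normalize with the current batch's mean and *biased* variance.
    bn.train()
    y_train = bn(x)
    batch_mean = x.mean(dim=(0, 2, 3))
    batch_var = x.var(dim=(0, 2, 3), unbiased=False)
    manual_train = normalize(x, batch_mean, batch_var, bn)

    # The same pass updates the running estimates (exponential moving average);
    # the running variance is fed the *unbiased* batch variance.
    n = x.numel() // x.size(1)  # elements per channel
    expected_mean = (1 - bn.momentum) * 0.0 + bn.momentum * batch_mean
    expected_var = (1 - bn.momentum) * 1.0 + bn.momentum * batch_var * n / (n - 1)

    # Eval mode: normalize with the running estimates instead of batch statistics.
    bn.eval()
    y_eval = bn(x)
    manual_eval = normalize(x, bn.running_mean, bn.running_var, bn)

print(torch.allclose(y_train, manual_train, atol=1e-5))           # True
print(torch.allclose(bn.running_mean, expected_mean, atol=1e-5))  # True
print(torch.allclose(bn.running_var, expected_var, atol=1e-5))    # True
print(torch.allclose(y_eval, manual_eval, atol=1e-5))             # True
print("max |train - eval| =", (y_train - y_eval).abs().max().item())
```

After a single batch the running estimates have moved only 10% of the way toward the batch statistics, so the train and eval outputs still differ substantially, which mirrors the gap between ImageNet statistics and random-noise statistics in the question.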