Tags: python, deep-learning, pytorch, batch-normalization, faster-rcnn

Not able to switch off batch norm layers for faster-rcnn (PyTorch)


I'm trying to switch off batch norm layers in a faster-rcnn model for evaluation mode.

I'm doing a sanity check at the moment:

import torch

@torch.no_grad()
def evaluate_loss(model, data_loader, device):
    val_loss = 0
    model.train()
    for images, targets in data_loader:
        # check that all layers are in train mode
        # for name, module in model.named_modules():
        #     if hasattr(module, 'training'):
        #         print('{} is training {}'.format(name, module.training))

        # set bn layers to eval
        for module in model.modules():
            if isinstance(module, torch.nn.BatchNorm2d):
                module.eval()
        # bn layers should now be in eval
        for name, module in model.named_modules():
            if hasattr(module, 'training'):
                print('{} is training {}'.format(name, module.training))

However, all the batch norm layers are still reported as being in training mode. When I replace BatchNorm2d in the isinstance check with, for example, Conv2d, I get the expected behaviour of False. Here is an example snippet of the output:

backbone.body.layer4.0.conv1 is training True
backbone.body.layer4.0.bn1 is training True
backbone.body.layer4.0.conv2 is training True
backbone.body.layer4.0.bn2 is training True
backbone.body.layer4.0.conv3 is training True
backbone.body.layer4.0.bn3 is training True
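
For reference, here is a minimal sketch of the Conv2d check described above (same model object as in the snippet; this is the variant that behaves as expected):

for module in model.modules():
    if isinstance(module, torch.nn.Conv2d):
        module.eval()
# unlike the batch norm layers, every Conv2d now reports False
for name, module in model.named_modules():
    if isinstance(module, torch.nn.Conv2d):
        print('{} is training {}'.format(name, module.training))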

Why is this happening? What can I do to switch off these layers? I have tried this with all of the batch norm variants provided by torch.nn.


Solution

  • So, after further investigation and after printing out all of the modules provided by the faster-rcnn, it turns out that the pretrained model uses FrozenBatchNorm2d instead of BatchNorm2d, which is why the isinstance check against BatchNorm2d never matches (see the sketch after this list).

    Furthermore, contrary to what the documentation currently states, you must refer to the class as torchvision.ops.misc.FrozenBatchNorm2d rather than torchvision.ops.FrozenBatchNorm2d.

    Additionally, as the layers are already frozen, there is no need to "switch them off", so setting them to eval() is probably not required.
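
    As a quick check, here is a minimal sketch that confirms this (assuming the pretrained fasterrcnn_resnet50_fpn from torchvision; the counts are illustrative):

    import torch
    import torchvision
    from torchvision.ops.misc import FrozenBatchNorm2d

    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)

    # count regular vs frozen batch norm layers in the pretrained model
    bn = sum(isinstance(m, torch.nn.BatchNorm2d) for m in model.modules())
    frozen = sum(isinstance(m, FrozenBatchNorm2d) for m in model.modules())
    print('BatchNorm2d: {}, FrozenBatchNorm2d: {}'.format(bn, frozen))
    # expected: BatchNorm2d: 0, FrozenBatchNorm2d: <non-zero>

    # FrozenBatchNorm2d stores its statistics and affine parameters as
    # fixed buffers, so its forward pass ignores the training flag and
    # there is nothing to "switch off".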