Tags: pytorch, conv-neural-network

What does model.eval() do for a batch normalization layer?


Why does the test data use the mean and variance of all the training data? Is it to keep the distribution consistent? What is the difference in the BN layer's behavior between model.train() and model.eval()?


Solution

  • It fixes the mean and variance used for normalization to the estimates accumulated during training, which are stored in running_mean and running_var. See the PyTorch documentation.

    As noted there, the implementation is based on the paper Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift. Because the running statistics are accumulated over the whole training set, they give (assuming the training and test data are similarly distributed) a better estimate of the mean/variance for the unseen test set than the statistics of a single test batch would. A short sketch of this behavior follows after this list.

    A similar question has also been asked here: What does model.eval() do?
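
As a minimal sketch (not from the original answer, and using arbitrary feature count and batch statistics purely for illustration), the snippet below shows how a BatchNorm1d layer updates running_mean / running_var in train mode and then reuses those stored statistics, rather than the current batch's statistics, once model.eval() is called:

```python
import torch
import torch.nn as nn

# A single BatchNorm1d layer over 3 features (arbitrary choice for illustration).
bn = nn.BatchNorm1d(num_features=3)

# Training mode: each forward pass normalizes with the *batch* statistics and
# updates running_mean / running_var via an exponential moving average
# (controlled by the `momentum` argument, default 0.1).
bn.train()
for _ in range(100):
    x = torch.randn(32, 3) * 2.0 + 5.0   # batches with mean ~5, std ~2
    _ = bn(x)

print(bn.running_mean)   # approaches ~[5, 5, 5]
print(bn.running_var)    # approaches ~[4, 4, 4]

# Eval mode: the layer normalizes with the stored running_mean / running_var
# instead of the current batch's statistics, so the output for a sample no
# longer depends on which other samples happen to be in the same batch.
bn.eval()
x_test = torch.randn(4, 3) * 2.0 + 5.0
y = bn(x_test)
print(y.mean(dim=0))     # roughly zero, since the running stats match the data
```

This is why eval-mode outputs are deterministic per sample: the normalization constants are frozen estimates from training, not recomputed from the test batch.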