
Why does BatchNorm in PyTorch give different output in eval mode?


When using the PyTorch BatchNorm module, shouldn't out_1 be equal to out_2 in the example below? out_1 is computed with the batch statistics and out_2 with the running mean, but only one batch has been processed, so I expected them to match.

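import torch

test = torch.rand((2, 10))  # one batch: 2 samples, 10 features

norm = torch.nn.BatchNorm1d(10)

out_1 = norm(test)  # train mode (default): normalize with batch statistics
norm.train(False)   # equivalent to norm.eval()
out_2 = norm(test)  # eval mode: normalize with running_mean / running_var

print(f"using batch statistics: {out_1}")

print(f"using moving average: {out_2}")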

Solution

  • There is an open issue on the PyTorch GitHub repository that seems to describe the problem you are having:

    https://github.com/pytorch/pytorch/issues/100048
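As a sketch of the mechanism, assuming the documented BatchNorm1d defaults (momentum=0.1, running_mean initialized to zeros, running_var to ones, affine weight 1 and bias 0 at initialization): after one training step the running statistics are a 10%/90% blend of the batch statistics and those initial values, not the batch statistics themselves. There is also a smaller second effect: training mode normalizes with the biased variance, while the running variance is updated with the unbiased one, so even momentum=None (cumulative average) does not make the two outputs equal. The seed and the names expected_running_mean, manual_out_2 and norm_cma are illustrative only.

```python
import torch

torch.manual_seed(0)  # illustrative seed, only for reproducibility

test = torch.rand((2, 10))       # same shape as in the question
norm = torch.nn.BatchNorm1d(10)  # defaults: momentum=0.1, eps=1e-5

out_1 = norm(test)  # train mode: batch stats for output, running stats updated

batch_mean = test.mean(dim=0)
batch_var_unbiased = test.var(dim=0, unbiased=True)  # divides by N - 1

# running_mean starts at 0 and running_var at 1, so one update with
# momentum=0.1 leaves them mostly at their initial values.
expected_running_mean = 0.9 * torch.zeros(10) + 0.1 * batch_mean
expected_running_var = 0.9 * torch.ones(10) + 0.1 * batch_var_unbiased
print(torch.allclose(norm.running_mean, expected_running_mean))  # True
print(torch.allclose(norm.running_var, expected_running_var))    # True

norm.eval()
out_2 = norm(test)
# Eval mode uses the running statistics (weight=1, bias=0 at init).
manual_out_2 = (test - norm.running_mean) / torch.sqrt(norm.running_var + norm.eps)
print(torch.allclose(out_2, manual_out_2, atol=1e-6))  # True

# momentum=None: after one batch the running mean equals the batch mean,
# but running_var holds the unbiased variance, while training normalized
# with the biased one, so the outputs still differ.
norm_cma = torch.nn.BatchNorm1d(10, momentum=None)
out_1_cma = norm_cma(test)
norm_cma.eval()
out_2_cma = norm_cma(test)
print(torch.allclose(norm_cma.running_mean, batch_mean))  # True
print(torch.allclose(out_1_cma, out_2_cma))               # False
```

With a batch of 2, the unbiased variance is twice the biased one, which is why the momentum=None outputs differ by roughly a factor of 1/sqrt(2) rather than by a rounding error.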