
What do BatchNorm2d's running_mean / running_var mean in PyTorch?


I'd like to know what exactly the running_mean and running_var attributes of nn.BatchNorm2d are.

Here is some example code, where bn is an nn.BatchNorm2d module:

vector = torch.cat([
    torch.mean(self.conv3.bn.running_mean).view(1), torch.std(self.conv3.bn.running_mean).view(1),
    torch.mean(self.conv3.bn.running_var).view(1), torch.std(self.conv3.bn.running_var).view(1),
    torch.mean(self.conv5.bn.running_mean).view(1), torch.std(self.conv5.bn.running_mean).view(1),
    torch.mean(self.conv5.bn.running_var).view(1), torch.std(self.conv5.bn.running_var).view(1)
])

I couldn't figure out what running_mean and running_var mean from the official PyTorch documentation or the user community.

What do nn.BatchNorm2d.running_mean and nn.BatchNorm2d.running_var mean?


Solution

  • From the original Batchnorm paper:

    Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift,
    Sergey Ioffe and Christian Szegedy, ICML 2015

    Algorithm 1 of the paper shows how the statistics of a given batch are computed.
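Written out, Algorithm 1 (the Batch Normalizing Transform) computes the following over a mini-batch x_1, ..., x_m of one activation, with a learned scale γ, a learned shift β, and a small constant ε for numerical stability. Note that the batch variance divides by m, i.e. it is the biased estimator:

```latex
\begin{aligned}
\mu_{\mathcal{B}} &= \frac{1}{m}\sum_{i=1}^{m} x_i
  && \text{mini-batch mean} \\
\sigma_{\mathcal{B}}^{2} &= \frac{1}{m}\sum_{i=1}^{m} \left(x_i - \mu_{\mathcal{B}}\right)^{2}
  && \text{mini-batch variance} \\
\hat{x}_i &= \frac{x_i - \mu_{\mathcal{B}}}{\sqrt{\sigma_{\mathcal{B}}^{2} + \epsilon}}
  && \text{normalize} \\
y_i &= \gamma \hat{x}_i + \beta
  && \text{scale and shift}
\end{aligned}
```

In PyTorch terms, μ_B and σ²_B are the per-channel statistics of the current batch; running_mean and running_var are moving averages of them (with the variance switched to the unbiased m/(m - 1) estimate) that replace μ_B and σ²_B at evaluation time.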


    However, what is kept in memory across batches are the running statistics, i.e. statistics that are updated iteratively on every forward pass in training mode. How the running mean and running variance are computed is explained quite well on the documentation page of nn.BatchNorm2d:

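    x̂_new = (1 - momentum) × x̂ + momentum × x_t, where x̂ is the running estimate (running_mean or running_var) and x_t is the corresponding statistic observed on the current batch.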

    By default, the momentum coefficient is set to 0.1. It regulates how much the current batch statistics affect the running statistics:

    • closer to 1 means the new running stat is closer to the current batch statistics, whereas

    • closer to 0 means the current batch stats will not contribute much to updating the new running stats.
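A minimal plain-Python sketch of this exponential moving average makes the role of momentum concrete. The helper update_running_stat is hypothetical (it is not a PyTorch function); it just applies the update formula above to a single scalar statistic:

```python
def update_running_stat(running: float, batch_stat: float, momentum: float = 0.1) -> float:
    """One step of the moving average used for BatchNorm running statistics.

    new_running = (1 - momentum) * running + momentum * batch_stat
    """
    return (1.0 - momentum) * running + momentum * batch_stat


# Start from the default running_mean of 0 and feed a constant batch mean of 1.0.
# With momentum=0.1 the estimate approaches 1.0 geometrically (1 - 0.9**n after n steps).
running = 0.0
history = []
for _ in range(3):
    running = update_running_stat(running, 1.0)
    history.append(round(running, 4))

print(history)  # [0.1, 0.19, 0.271]
```

With momentum=1.0 the update simply copies the latest batch statistic, and with momentum=0.0 the running value never changes. This convention is the reverse of the momentum term in optimizers such as SGD, where a larger value keeps more of the old state.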

    It's worth pointing out that BatchNorm2d computes its statistics across the spatial dimensions *in addition* to the batch dimension. Given a batch of shape (b, c, h, w), it computes the statistics across (b, h, w). This means the running statistics have shape (c,), i.e. there are as many components as there are input channels (for both mean and variance).
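A quick check of these shapes, as a sketch with arbitrary sizes (b, c, h, w) = (4, 3, 5, 5) chosen purely for illustration:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

b, c, h, w = 4, 3, 5, 5
bn = nn.BatchNorm2d(c)   # a newly created module is in training mode
x = torch.randn(b, c, h, w)

bn(x)                    # one training-mode forward pass updates the running stats

# One statistic per channel: shape (c,)
print(bn.running_mean.shape, bn.running_var.shape)  # torch.Size([3]) torch.Size([3])

# The batch mean for each channel pools over dims 0, 2, 3 (batch, height, width)
per_channel_mean = x.mean(dim=(0, 2, 3))
print(per_channel_mean.shape)  # torch.Size([3])
```

Because running_mean starts at zero, after this single forward pass it equals 0.1 times per_channel_mean.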

    Here is a minimal example:

    >>> import torch
    >>> import torch.nn as nn
    >>> bn = nn.BatchNorm2d(10)
    >>> x = torch.rand(2,10,2,2)

    Since track_running_stats is True by default on BatchNorm2d, the layer updates its running stats on every forward pass in training mode.

    The running mean and variance are initialized to zeros and ones, respectively.
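These initial values can be inspected directly on a freshly created layer. The sketch below (with an arbitrary channel count of 4) also prints the num_batches_tracked buffer, which counts training-mode forward passes, and shows that with track_running_stats=False no running buffers are kept at all:

```python
import torch
import torch.nn as nn

bn = nn.BatchNorm2d(4)
print(bn.running_mean)         # tensor([0., 0., 0., 0.])
print(bn.running_var)          # tensor([1., 1., 1., 1.])
print(bn.num_batches_tracked)  # tensor(0)

# With tracking disabled, the running buffers are None and batch statistics
# are used even in eval mode
bn_no_track = nn.BatchNorm2d(4, track_running_stats=False)
print(bn_no_track.running_mean, bn_no_track.running_var)  # None None
```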

    >>> running_mean, running_var = torch.zeros(x.size(1)), torch.ones(x.size(1))
    >>> momentum = bn.momentum  # 0.1 by default

    Let's run a forward pass of bn on x in training mode (the default for a newly created module) and check its running stats:

    >>> bn(x)
    >>> bn.running_mean, bn.running_var
    (tensor([0.0650, 0.0432, 0.0373, 0.0534, 0.0476, 
             0.0622, 0.0651, 0.0660, 0.0406, 0.0446]),
     tensor([0.9027, 0.9170, 0.9162, 0.9082, 0.9087, 
             0.9026, 0.9136, 0.9043, 0.9126, 0.9122]))
    

    Now let's compute those stats by hand:

    >>> xmean = x.mean([0,2,3]) # Mean over batch, height, and width
    >>> xvar = x.var([0,2,3], unbiased=True) # Unbiased variance, as used for running_var
    >>> (1-momentum)*running_mean + momentum*xmean
    tensor([0.0650, 0.0432, 0.0373, 0.0534, 0.0476,
            0.0622, 0.0651, 0.0660, 0.0406, 0.0446])

    >>> (1-momentum)*running_var + momentum*xvar
    tensor([0.9027, 0.9170, 0.9162, 0.9082, 0.9087,
            0.9026, 0.9136, 0.9043, 0.9126, 0.9122])
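Since x comes from torch.rand without a fixed seed, the exact numbers above will differ from run to run. Below is a self-contained sketch that repeats the same by-hand check over several batches, comparing with torch.allclose instead of eyeballing printed values, and then shows what the running statistics are used for: in eval mode the layer normalizes with them instead of the batch statistics. The seed, sizes, and number of batches are arbitrary choices for illustration:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

c = 10
bn = nn.BatchNorm2d(c)   # momentum=0.1 and eps=1e-5 by default
momentum = bn.momentum

# Manual copies of the running buffers, initialized like PyTorch's (zeros / ones)
running_mean = torch.zeros(c)
running_var = torch.ones(c)

# Several forward passes in training mode, mirroring each update by hand
for _ in range(5):
    x = torch.rand(2, c, 2, 2)
    bn(x)
    running_mean = (1 - momentum) * running_mean + momentum * x.mean([0, 2, 3])
    # The running variance is updated with the unbiased (n - 1) batch variance
    running_var = (1 - momentum) * running_var + momentum * x.var([0, 2, 3], unbiased=True)

print(torch.allclose(bn.running_mean, running_mean, atol=1e-6))  # True
print(torch.allclose(bn.running_var, running_var, atol=1e-6))    # True

# In eval mode the layer normalizes with the stored running stats, not batch stats,
# and a forward pass leaves the running buffers untouched.
bn.eval()
frozen_mean = bn.running_mean.clone()
x_test = torch.rand(2, c, 2, 2)
with torch.no_grad():
    y = bn(x_test)
    shape = (1, c, 1, 1)  # broadcast per-channel values over (b, c, h, w)
    expected = (x_test - bn.running_mean.view(shape)) / torch.sqrt(bn.running_var.view(shape) + bn.eps)
    expected = expected * bn.weight.view(shape) + bn.bias.view(shape)

print(torch.allclose(y, expected, atol=1e-5))     # True
print(torch.equal(bn.running_mean, frozen_mean))  # True
```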