
Batch normalization over which dimension?


Over which dimension do we calculate the mean and std? Is it over the hidden dimensions of the NN Layer, or over all the samples in the batch for every hidden dimension separately?

In the paper it says we normalize over the batch.

In torch.nn.BatchNorm1d however the input argument is num_features, which makes no sense to me.

Why does PyTorch not follow the original paper on batch normalization?


Solution

  • over which dimension do we calculate the mean and std?

    Over the 0th (batch) dimension. For an input of shape (batch, num_features) it would be:

    import torch

    batch = 64
    features = 12
    data = torch.randn(batch, features)

    # statistics are computed per feature, across the batch dimension (dim=0)
    mean = torch.mean(data, dim=0)                # shape: (features,)
    var = torch.var(data, dim=0, unbiased=False)  # shape: (features,); BatchNorm uses the biased estimator
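
    For instance, the following quick check (a sketch, assuming the default affine initialization gamma=1, beta=0) shows that torch.nn.BatchNorm1d in training mode reproduces exactly this per-feature normalization:

    bn = torch.nn.BatchNorm1d(features)
    bn.train()
    out = bn(data)

    manual = (data - mean) / torch.sqrt(var + bn.eps)
    print(torch.allclose(out, manual, atol=1e-6))  # True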
    

    In torch.nn.BatchNorm1d, however, the input argument is "num_features", which makes no sense to me.

    It is not related to the normalization itself but to the scale-and-shift reparametrization via the learnable gamma and beta parameters. From the paper:

    mu_B  = mean(x_batch)                        # mini-batch mean
    var_B = mean((x_batch - mu_B)^2)             # mini-batch variance
    x_hat = (x - mu_B) / sqrt(var_B + eps)       # normalize
    y     = gamma * x_hat + beta                 # scale and shift

    Both parameters used in the scale-and-shift step have shape (num_features,), hence we have to pass this value in order to initialize them with the correct shape.
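
    This is easy to verify on a freshly constructed layer: PyTorch stores gamma as weight, beta as bias, and the running statistics as buffers, all of shape (num_features,).

    bn = torch.nn.BatchNorm1d(12)
    print(bn.weight.shape)        # torch.Size([12])  <- gamma
    print(bn.bias.shape)          # torch.Size([12])  <- beta
    print(bn.running_mean.shape)  # torch.Size([12])
    print(bn.running_var.shape)   # torch.Size([12])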

    Below is an example from-scratch implementation for reference:

    class BatchNorm1d(torch.nn.Module):
        def __init__(self, num_features, momentum: float = 0.9, eps: float = 1e-7):
            super().__init__()
            self.num_features = num_features

            # learnable scale (gamma) and shift (beta), one value per feature
            self.gamma = torch.nn.Parameter(torch.ones(1, self.num_features))
            self.beta = torch.nn.Parameter(torch.zeros(1, self.num_features))

            # running statistics used at inference time
            self.register_buffer("running_mean", torch.zeros(1, self.num_features))
            self.register_buffer("running_var", torch.ones(1, self.num_features))

            # note: here momentum weights the *old* running statistic;
            # torch.nn.BatchNorm1d uses the complementary convention
            self.momentum = momentum
            self.eps = eps

        def forward(self, X):
            if not self.training:
                # inference: normalize with the accumulated running statistics
                X_hat = (X - self.running_mean) / torch.sqrt(self.running_var + self.eps)
            else:
                # training: statistics over the batch dimension, one per feature
                mean = X.mean(dim=0, keepdim=True)
                var = ((X - mean) ** 2).mean(dim=0, keepdim=True)

                # update running mean and variance outside the autograd graph
                with torch.no_grad():
                    self.running_mean *= self.momentum
                    self.running_mean += (1 - self.momentum) * mean

                    self.running_var *= self.momentum
                    self.running_var += (1 - self.momentum) * var

                X_hat = (X - mean) / torch.sqrt(var + self.eps)

            # scale and shift
            return X_hat * self.gamma + self.beta
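
    As a quick sanity check (a sketch, assuming both modules run in training mode with the same eps), the from-scratch module above matches torch.nn.BatchNorm1d on the forward pass:

    x = torch.randn(64, 12)
    custom_bn = BatchNorm1d(num_features=12)     # the class defined above
    ref_bn = torch.nn.BatchNorm1d(12, eps=1e-7)  # same eps as the custom default

    custom_bn.train()
    ref_bn.train()
    print(torch.allclose(custom_bn(x), ref_bn(x), atol=1e-5))  # True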
    

    Why does PyTorch not follow the original paper on batch normalization?

    It does, as one can see above: the mean and variance are computed over the batch dimension, while num_features is only needed to give the learnable gamma and beta parameters and the running statistics their per-feature shape.