Tags: python, pytorch, batch-normalization

How to implement BatchNorm2d in PyTorch myself?


I'm trying to implement a BatchNorm2d() layer myself with:

import torch
import torch.nn as nn
from torch.nn import Parameter

class BatchNorm2d(nn.Module):

    def __init__(self, num_features):
        super(BatchNorm2d, self).__init__()
        self.num_features = num_features
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.eps = 1e-5
        self.momentum = 0.1
        self.first_run = True

    def forward(self, input):
        # input: [batch_size, num_feature_map, height, width]
        device = input.device
        if self.training:
            mean = torch.mean(input, dim=0, keepdim=True).to(device)  # [1, num_feature, height, width]
            var = torch.var(input, dim=0, unbiased=False, keepdim=True).to(device)  # [1, num_feature, height, width]
            if self.first_run:
                self.weight = Parameter(torch.randn(input.shape, dtype=torch.float32, device=device), requires_grad=True)
                self.bias = Parameter(torch.randn(input.shape, dtype=torch.float32, device=device), requires_grad=True)
                self.register_buffer('running_mean', torch.zeros(input.shape).to(input.device))
                self.register_buffer('running_var', torch.ones(input.shape).to(input.device))
                self.first_run = False
            self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean
            self.running_var = (1 - self.momentum) * self.running_var + self.momentum * var
            bn_init = (input - mean) / torch.sqrt(var + self.eps)
        else:
            bn_init = (input - self.running_mean) / torch.sqrt(self.running_var + self.eps)
        return self.weight * bn_init + self.bias

But after training and testing I found that the results using my layer are not comparable with the results using nn.BatchNorm2d(). There must be something wrong with it, and I guess the problem is related to initializing the parameters in forward()? I did that because I don't know how to get the shape of the input in __init__(); maybe there is a better way. I don't know how to fix it, please help. Thanks!!


Solution

  • Got answers from HERE!
    So the shape of weight (and bias) should be (1, num_features, 1, 1), not (1, num_features, height, width).
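
For reference, here is a minimal corrected sketch based on that answer. It is my own reconstruction, not code from the linked post (the name MyBatchNorm2d and the default arguments are illustrative): the parameters and running buffers are created in __init__ with shape (1, num_features, 1, 1), and the statistics are reduced over the batch and spatial dimensions (0, 2, 3). One detail needed for an exact match with nn.BatchNorm2d: the built-in layer stores the unbiased variance in running_var, while the normalization itself uses the biased one.

import torch
import torch.nn as nn

class MyBatchNorm2d(nn.Module):

    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        super().__init__()
        self.eps = eps
        self.momentum = momentum
        # One affine scale/shift per channel, broadcast over N, H, W.
        # nn.BatchNorm2d initializes weight to ones and bias to zeros.
        self.weight = nn.Parameter(torch.ones(1, num_features, 1, 1))
        self.bias = nn.Parameter(torch.zeros(1, num_features, 1, 1))
        # Buffers are saved in state_dict and follow .to()/.cuda(), but get no gradients.
        self.register_buffer('running_mean', torch.zeros(1, num_features, 1, 1))
        self.register_buffer('running_var', torch.ones(1, num_features, 1, 1))

    def forward(self, input):
        # input: [batch_size, num_features, height, width]
        if self.training:
            # Per-channel statistics over batch and spatial dims -> shape (1, C, 1, 1).
            mean = input.mean(dim=(0, 2, 3), keepdim=True)
            var = input.var(dim=(0, 2, 3), unbiased=False, keepdim=True)
            n = input.numel() / input.size(1)  # samples per channel
            with torch.no_grad():
                self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean
                # nn.BatchNorm2d tracks the unbiased variance, hence the n/(n-1) correction.
                self.running_var = (1 - self.momentum) * self.running_var + self.momentum * var * n / (n - 1)
        else:
            mean, var = self.running_mean, self.running_var
        x_hat = (input - mean) / torch.sqrt(var + self.eps)
        return self.weight * x_hat + self.bias

A quick sanity check against the built-in layer in training mode (sizes are arbitrary):

x = torch.randn(4, 3, 8, 8)
mine, ref = MyBatchNorm2d(3), nn.BatchNorm2d(3)
print(torch.allclose(mine(x), ref(x), atol=1e-6))                   # outputs match
print(torch.allclose(mine.running_var.flatten(), ref.running_var))  # running stats match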