python · neural-network · deep-learning · pytorch · batch-normalization

How to do fully connected batch norm in PyTorch?


torch.nn has the classes BatchNorm1d, BatchNorm2d, and BatchNorm3d, but it doesn't have a fully connected BatchNorm class. What is the standard way of doing normal batch norm in PyTorch?


Solution

  • Ok, I figured it out. BatchNorm1d can also handle rank-2 tensors, so it is possible to use BatchNorm1d for the normal fully-connected case.

    So for example:

    import torch.nn as nn
    import torch.nn.functional as F  # F.relu is used in forward()
    
    
    class Policy(nn.Module):
        def __init__(self, num_inputs, action_space, hidden_size1=256, hidden_size2=128):
            super(Policy, self).__init__()
            self.action_space = action_space
            num_outputs = action_space
    
            self.linear1 = nn.Linear(num_inputs, hidden_size1)
            self.linear2 = nn.Linear(hidden_size1, hidden_size2)
            self.linear3 = nn.Linear(hidden_size2, num_outputs)
            # BatchNorm1d normalizes over the feature dimension of a
            # rank-2 (batch, features) tensor, which is exactly the
            # output shape of a fully connected layer.
            self.bn1 = nn.BatchNorm1d(hidden_size1)
            self.bn2 = nn.BatchNorm1d(hidden_size2)
    
        def forward(self, inputs):
            x = inputs
            x = self.bn1(F.relu(self.linear1(x)))
            x = self.bn2(F.relu(self.linear2(x)))
            out = self.linear3(x)
            return out
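
    A minimal usage sketch (the dimensions num_inputs=4 and action_space=2 below are placeholders, not part of the original question): BatchNorm layers compute per-batch statistics in training mode, so a batch size greater than 1 is needed there, and you should call model.eval() at inference time so the running statistics are used instead.

    import torch

    # Placeholder dimensions, chosen only for illustration.
    model = Policy(num_inputs=4, action_space=2)

    # Training mode: BatchNorm1d needs more than one sample per batch
    # to compute batch statistics.
    batch = torch.randn(8, 4)   # rank-2 (batch, features) input
    out = model(batch)          # shape: (8, 2)

    # Inference: switch to eval mode so BatchNorm uses its running
    # statistics; batch size 1 is fine here.
    model.eval()
    single = model(torch.randn(1, 4))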