Tags: python, machine-learning, deep-learning, neural-network, pytorch

How do I initialize weights in PyTorch?


How do I initialize the weights and biases of a network (e.g. with He or Xavier initialization)?


Solution

  • Single layer

    To initialize the weights of a single layer, use a function from torch.nn.init. For instance:

    import torch

    conv1 = torch.nn.Conv2d(...)
    torch.nn.init.xavier_uniform_(conv1.weight)  # in-place; replaces the deprecated xavier_uniform
    

    Alternatively, you can modify the parameters in place by writing to conv1.weight.data, the underlying torch.Tensor, which autograd does not track. For example:

    conv1.weight.data.fill_(0.01)
    

    The same applies to biases:

    conv1.bias.data.fill_(0.01)
    

  • nn.Sequential or custom nn.Module

    Pass an initialization function to torch.nn.Module.apply. It calls that function on every submodule recursively (and on the module itself), so the weights of the entire nn.Module get initialized.

    apply(fn): Applies fn recursively to every submodule (as returned by .children()) as well as self. Typical use includes initializing the parameters of a model (see also torch.nn.init).

    Example:

    import torch
    from torch import nn

    def init_weights(m):
        # Called once per module; only nn.Linear layers are modified.
        if isinstance(m, nn.Linear):
            torch.nn.init.xavier_uniform_(m.weight)
            m.bias.data.fill_(0.01)

    net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
    net.apply(init_weights)
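
The question also asks about He initialization. torch.nn.init provides kaiming_normal_ and kaiming_uniform_ (He initialization) alongside xavier_normal_ and xavier_uniform_, and you call them on a single layer exactly like the xavier_uniform_ call in the single-layer example. The sketch below assumes a ReLU-activated nn.Linear layer. Its sizes are arbitrary and chosen only for illustration:

```python
import torch
from torch import nn

# Hypothetical layer; the sizes are arbitrary.
layer = nn.Linear(256, 128)

# He (Kaiming) normal initialization for a layer followed by ReLU.
# mode="fan_in" scales by the number of inputs (256 here).
nn.init.kaiming_normal_(layer.weight, mode="fan_in", nonlinearity="relu")

# Zero biases are a common companion choice.
nn.init.zeros_(layer.bias)
```

With mode="fan_in" and the ReLU gain of √2, the weights are drawn with standard deviation sqrt(2 / fan_in), which here is sqrt(2 / 256) ≈ 0.088.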
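
Writing through .data works, but it bypasses autograd. An equally common pattern is to update the parameter itself in place while gradient tracking is disabled with torch.no_grad(). The following is a sketch, with layer sizes chosen only for illustration:

```python
import torch
from torch import nn

# Hypothetical sizes; any Conv2d behaves the same way.
conv1 = nn.Conv2d(in_channels=3, out_channels=8, kernel_size=3)

# Parameters require grad, so in-place ops on them raise an error
# unless gradient tracking is disabled for the update.
with torch.no_grad():
    conv1.weight.fill_(0.01)
    conv1.bias.fill_(0.01)
```

After the block exits, the parameters still have requires_grad=True, so training proceeds normally.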
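
The apply documentation says the recursion covers every submodule "as well as self". To see what that means in practice, this small sketch records which modules apply visits in a nested model. The recording lambda and the visited list are illustrative only. A module's children are visited before the module itself:

```python
from torch import nn

visited = []  # hypothetical list used only to record visit order

net = nn.Sequential(
    nn.Linear(2, 2),
    nn.Sequential(nn.ReLU(), nn.Linear(2, 2)),  # nested container
)

# The lambda just logs each module's class name.
net.apply(lambda m: visited.append(type(m).__name__))
```

Here visited ends up as Linear, ReLU, Linear, then the inner Sequential, then the outer Sequential. An initialization function should therefore check the module type, as init_weights does with isinstance.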
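
The init_weights function in the nn.Sequential example only handles nn.Linear and assumes every layer has a bias. Below is a slightly more general sketch for a hypothetical model that mixes nn.Conv2d and nn.Linear layers. It uses He initialization for the convolutions and Xavier for the linear layers. It also skips the bias of any layer created with bias=False, because that layer's bias attribute is None. The model and the per-layer choices are illustrative assumptions, not the only reasonable scheme:

```python
import torch
from torch import nn

def init_weights(m: nn.Module) -> None:
    # Convolutions: He init; Linear: Xavier; everything else is ignored.
    if isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
    elif isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)
    else:
        return
    # Layers built with bias=False have m.bias set to None.
    if m.bias is not None:
        nn.init.zeros_(m.bias)

# Hypothetical model: 3x8x8 inputs, 10 outputs.
model = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1, bias=False),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(16 * 8 * 8, 10),
)
model.apply(init_weights)

# Quick forward pass to confirm the initialized model runs.
out = model(torch.randn(4, 3, 8, 8))
```

Because apply returns the module itself, the call can also be chained directly onto the model's construction.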