deep-learning | pytorch | gradient-descent

Parameter-specific learning rate in PyTorch


How can I set a learning rate for each specific parameter (weights and biases) in a network?

On PyTorch's docs I found this:

optim.SGD([{'params': model.base.parameters()}, 
           {'params': model.classifier.parameters(), 'lr': 1e-3}], 
           lr=1e-2, momentum=0.9)

where model.classifier.parameters(), which defines a group of parameters, gets a specific learning rate of 1e-3.

But how can I translate this to the level of individual parameters?


Solution

  • You can set a parameter-specific learning rate by referring to each parameter by name when building the optimizer's parameter groups, e.g.

    For a given network, taken from the PyTorch forum:

    import torch
    import torch.nn as nn
    import torch.optim as optim

    class Net(nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            self.layer1 = nn.Linear(1, 1)
            self.layer1.weight.data.fill_(1)
            self.layer1.bias.data.fill_(1)
            self.layer2 = nn.Linear(1, 1)
            self.layer2.weight.data.fill_(1)
            self.layer2.bias.data.fill_(1)
    
        def forward(self, x):
            x = self.layer1(x)
            return self.layer2(x)
    
    net = Net()
    for name, param in net.named_parameters():
        print(name)
    

    The parameters are:

    layer1.weight
    layer1.bias
    layer2.weight
    layer2.bias
    

    Then, you can use the parameter names to set their specific learning rates as follows:

    # layer1.weight falls back to the default lr=0.1; layer2.bias is not
    # listed in any group, so the optimizer will not update it.
    optimizer = optim.Adam([
                {'params': net.layer1.weight},
                {'params': net.layer1.bias, 'lr': 0.01},
                {'params': net.layer2.weight, 'lr': 0.001}
            ], lr=0.1, weight_decay=0.0001)
    
    out = net(torch.Tensor([[1]]))
    out.backward()
    optimizer.step()
    print("weight", net.layer1.weight.data.numpy(), "grad", net.layer1.weight.grad.data.numpy())
    print("bias", net.layer1.bias.data.numpy(), "grad", net.layer1.bias.grad.data.numpy())
    print("weight", net.layer2.weight.data.numpy(), "grad", net.layer2.weight.grad.data.numpy())
    print("bias", net.layer2.bias.data.numpy(), "grad", net.layer2.bias.grad.data.numpy())
    

    Output:

    weight [[0.9]] grad [[1.0001]]
    bias [0.99] grad [1.0001]
    weight [[0.999]] grad [[2.0001]]
    bias [1.] grad [1.]
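
    Since net.layer2.bias was not placed in any parameter group, its gradient is still computed by backward(), but the optimizer never updates it, which is why it stays at 1. in the last line.

    If you would rather not list every parameter by hand, you can build one parameter group per parameter from named_parameters(). A minimal sketch, assuming a hypothetical lr_by_name dict and default_lr (neither is part of the original answer):

    # Hypothetical per-parameter learning rates, keyed by the names printed
    # by net.named_parameters() above.
    lr_by_name = {
        'layer1.weight': 0.1,
        'layer1.bias': 0.01,
        'layer2.weight': 0.001,
        'layer2.bias': 0.0005,
    }
    default_lr = 0.1

    # One parameter group per parameter; names missing from the dict fall
    # back to default_lr.
    param_groups = [
        {'params': [param], 'lr': lr_by_name.get(name, default_lr)}
        for name, param in net.named_parameters()
    ]

    optimizer = optim.Adam(param_groups, lr=default_lr, weight_decay=0.0001)

    # Each group's learning rate can be inspected via optimizer.param_groups.
    for (name, _), group in zip(net.named_parameters(), optimizer.param_groups):
        print(name, group['lr'])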