Tags: pytorch, autograd

Activation gradient penalty


Here's a simple neural network where I'm trying to penalize the norm of the activation gradients:

import torch
from torch import nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=5)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=5)
        self.pool = nn.MaxPool2d(2, 2)
        self.relu = nn.ReLU()
        self.linear = nn.Linear(64 * 5 * 5, 10)

    def forward(self, input):
        conv1 = self.conv1(input)
        pool1 = self.pool(conv1)
        self.relu1 = self.relu(pool1)
        self.relu1.retain_grad()
        conv2 = self.conv2(self.relu1)
        pool2 = self.pool(conv2)
        relu2 = self.relu(pool2)
        self.relu2 = relu2.view(relu2.size(0), -1)
        self.relu2.retain_grad()
        return self.linear(self.relu2)

model = Net()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

for i in range(1000):
    output = model(input)
    loss = nn.CrossEntropyLoss()(output, label)
    optimizer.zero_grad()
    loss.backward(retain_graph=True)

    grads = torch.autograd.grad(loss, [model.relu1, model.relu2], create_graph=True)

    grad_norm = 0
    for grad in grads:
        grad_norm += grad.pow(2).sum()

    grad_norm.backward()

    optimizer.step()

However, it does not produce the desired regularization effect. If I do the same thing for weights (instead of activations), it works well. Am I doing this right (in terms of pytorch machinery)?

Specifically, what happens in the grad_norm.backward() call? I just want to make sure the weight gradients are updated, and not the activation gradients. Currently, when I print out the gradients for weights and activations immediately before and after that line, both change, so I'm not sure what's going on.


Solution

  • I think your code ends up computing some of the gradients twice in each step (once in loss.backward(retain_graph=True) and again in the autograd.grad call). I also suspect it never properly zeroes out the activation gradients: optimizer.zero_grad() only clears the parameters' .grad attributes, so the gradients from the two backward passes accumulate in the retained activations.

    In general:

    • x.backward() computes the gradient of x wrt. the computation graph's leaves (e.g. weight tensors and other variables), as well as wrt. nodes explicitly marked with retain_grad(). It accumulates the computed gradients in those tensors' .grad attributes.

    • autograd.grad(x, [y, z]) returns the gradient of x wrt. y and z regardless of whether they would normally retain grad or not, and by default (only_inputs=True) it does not write into any .grad attribute. Passing only_inputs=False used to make it also accumulate gradients in all remaining leaves' .grad attributes, but that flag is deprecated and ignored in recent versions of PyTorch. A minimal sketch of the difference follows this list.
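
    For instance, here is a minimal sketch of the difference (my own illustration, not part of the original answer); the printed numbers are for the tiny graph built below:

    import torch

    # a graph leaf, playing the role of a weight
    w = torch.tensor(2.0, requires_grad=True)
    # an intermediate node, playing the role of an activation
    a = w * 3
    a.retain_grad()            # ask autograd to populate a.grad
    loss = a ** 2

    loss.backward(retain_graph=True)
    print(w.grad, a.grad)      # tensor(36.) tensor(12.) -- both .grad attributes filled

    (g,) = torch.autograd.grad(loss, [a])
    print(g)                   # tensor(12.) -- the gradient is returned directly
    print(w.grad, a.grad)      # unchanged: autograd.grad wrote into no .grad attribute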

    I prefer to use backward() only for the optimization step, and autograd.grad() whenever my goal is to obtain "reified" gradients as intermediate values for another computation. This way, I can be sure that no unwanted gradients remain lying around in tensors' .grad attributes after I'm done with them.

    import torch
    from torch import nn
    class Net(nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            self.conv1 = nn.Conv2d(3, 32, kernel_size=5)
            self.conv2 = nn.Conv2d(32, 64, kernel_size=5)
            self.pool = nn.MaxPool2d(2, 2)
            self.relu = nn.ReLU()
            self.linear = nn.Linear(64 * 5 * 5, 10)
    
        def forward(self, input):
            conv1 = self.conv1(input)
            pool1 = self.pool(conv1)
            self.relu1 = self.relu(pool1)
            conv2 = self.conv2(self.relu1)
            pool2 = self.pool(conv2)
            self.relu2 = self.relu(pool2)
            relu2 = self.relu2.view(self.relu2.size(0), -1)
            return self.linear(relu2)
    
    
    model = Net()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
    grad_penalty_weight = 10.
    
    for i in range(1000000):
        # Random input and labels; we're not really learning anything
        input = torch.rand(1, 3, 32, 32)
        label = torch.randint(0, 10, (1,))
    
        output = model(input)
        loss = nn.CrossEntropyLoss()(output, label)
    
        # This is where the activation gradients are computed
        # only_inputs=True is the default (the flag is deprecated in recent PyTorch),
        # but passing it makes clear that we're *only* interested in the activation gradients
        grads = torch.autograd.grad(loss, [model.relu1, model.relu2], create_graph=True, only_inputs=True)
    
        grad_norm = 0
        for grad in grads:
            grad_norm += grad.pow(2).sum()
    
        optimizer.zero_grad()
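        # Fold the penalty into the task loss and backpropagate once through both terms;
        # create_graph=True above is what made grad_norm itself differentiable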
        loss = loss + grad_norm * grad_penalty_weight
        loss.backward()
        optimizer.step()
    

    This code appears to work, in that the activation gradients do get smaller. I cannot comment on the viability of this technique as a regularization method.
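
    If you want to watch this happen, one simple check (my addition, not part of the original answer) is to log the penalty every so often inside the training loop above:

    # Hypothetical monitoring snippet -- place it at the end of the loop body
    if i % 1000 == 0:
        print(f"step {i}: squared activation-gradient norm = {grad_norm.item():.6f}")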