machine-learning, optimization, neural-network, pytorch, backpropagation

How to register a dynamic backward hook on tensors in Pytorch?


I'm trying to register a backward hook on each neuron's weights in a network. By dynamic I mean that it will take a value and multiply the associated gradients by that value.

From here it seems like it's possible to register a hook on a tensor with a fixed value (though note that I need it to take a value that will change). From here it also seems like it's possible to register a hook on all of the parameters -- they use it to do gradient clipping (though note that I'm trying to do it only on each neuron's weights).
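
Roughly, those two linked approaches look like this (my own minimal sketch, not copied from the links):

import torch
import torch.nn as nn

# Fixed-value hook on a single tensor (first linked approach)
t = torch.ones(3, requires_grad=True)
t.register_hook(lambda grad: grad * 2.0)

# Gradient-clipping hook on every parameter (second linked approach)
layer = nn.Linear(3, 5)
for p in layer.parameters():
    p.register_hook(lambda grad: torch.clamp(grad, -1.0, 1.0))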

If my network is as follows:

import torch
import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()

        self.fc1 = nn.Linear(3,5)
        self.fc2 = nn.Linear(5,10)
        self.fc3 = nn.Linear(10,1)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = torch.relu(self.fc3(x))
        return x 

The first layer has 5 neurons, each with 3 associated weights. Hence, this layer should have 5 hooks that modify (i.e. change the current gradient by multiplying it) their 3 associated weight gradients during the backward step.
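
For reference, fc1.weight has shape (out_features, in_features) = (5, 3), i.e. one row of 3 weights per neuron (a quick check, assuming the Model class above):

net = Model()
print(net.fc1.weight.shape)  # torch.Size([5, 3]) -- one row of 3 weights per neuron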

Training pseudo-code example:

net = Model()
for epoch in epochs:
    out = net(data)
    loss = criterion(out, target)
    optimizer.zero_grad()
    loss.backward()
    for hook in list_of_hooks: #not sure if there's a more "pytorch" way of doing this without a for loop
        hook(random_value)
    optimizer.step()

Solution

  • What about exploiting a lambda's closure over names?

    A short example:

    import torch
    
    net_params = torch.rand(5, 3, requires_grad=True)
    
    msg = "Hello!"
    
    net_params.register_hook(lambda g: print(msg))
    
    
    out1 = net_params * 2.
    
    loss = out1.sum()
    loss.backward()  # Activates the hook and prints "Hello!"
    
    
    msg = "How are you?"  # The lambda is affected by this change
    
    out2 = net_params ** 4.
    loss2 = out2.sum()
    
    loss2.backward()  # Activates the hook again and prints "How are you?"
    

    So a possible solution to your problem:

    net = Model()
    # Replace it with your computed values
    rand_values = torch.rand(net.fc1.out_features, net.fc1.in_features)
    
    net.fc1.weight.register_hook(lambda g: g * rand_values) 
    
    for epoch in epochs:
        out = net(data)
        loss = criterion(out, target)
        optimizer.zero_grad()
        loss.backward()  # fc1 gradients are multiplied by rand_values
        optimizer.step()
    
        # Update rand_values. The lambda computation will change accordingly
        rand_values = torch.rand(net.fc1.out_features, net.fc1.in_features)
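
    A side note (not something you asked about, just in case it's useful): register_hook returns a handle, so if you only want the scaling during part of training you can keep it and detach the hook later:

    handle = net.fc1.weight.register_hook(lambda g: g * rand_values)
    # ... run the training loop above ...
    handle.remove()  # after this, fc1 gradients are no longer modified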
    

    Edit

    To make things clearer: if you specifically want to multiply each neuron i's set of weights by a single value vi, you can exploit broadcasting semantics and define values = torch.tensor([v0, v1, v2, v3, v4]).reshape(5, 1); then the lambda becomes lambda g: g * values.
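
    A concrete sketch of that broadcast version (the five values below are just placeholders):

    net = Model()
    # One value per fc1 neuron; shape (5, 1) broadcasts over that neuron's 3 weight gradients
    values = torch.tensor([0.5, 1.0, 1.5, 2.0, 2.5]).reshape(5, 1)
    net.fc1.weight.register_hook(lambda g: g * values)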