
Scale gradient for specific parameter in pytorch?


Suppose I have some neural network with parameters A,B,C. Whenever a gradient update is applied to C, I want it to be scaled differently than what the normal gradient would be (i.e. every gradient update is 2x or 1/3x what the calculated gradient is). I want the gradients applied to all other parameters to stay the same. How would I go about this in pytorch?


Solution

  • You can directly access the gradient of your tensor through its grad attribute: C.grad *= z, where z is your multiplicative factor. Note that this must be done after backward() has populated the gradients and before the optimizer step.
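    If you want to do this manually at each step, a minimal sketch looks like this (the standalone scalar parameters and toy loss here are hypothetical, just to illustrate the idea):

```python
import torch
import torch.nn as nn

# Toy standalone parameters and a toy loss, purely for illustration.
A = nn.Parameter(torch.tensor(2.0))
B = nn.Parameter(torch.tensor(1.0))
C = nn.Parameter(torch.tensor(3.0))
x = torch.tensor(4.0)

z = 2.0  # multiplicative factor for C's gradient

loss = (x * A + B) ** C
loss.backward()

# Scale only C's gradient; A and B keep their computed gradients.
with torch.no_grad():
    C.grad *= z
```

    In a training loop, this scaling would sit between loss.backward() and optimizer.step().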

    One way to wrap this neatly is to use a backward hook with register_full_backward_hook on your module. Here is a minimal use case corresponding to your description:

    import torch
    import torch.nn as nn

    class CustomFunction(nn.Module):
        def __init__(self):
            super().__init__()
            self.A = nn.Parameter(torch.rand(1))
            self.B = nn.Parameter(torch.rand(1))
            self.C = nn.Parameter(torch.rand(1))

        def forward(self, x):
            return (x*self.A+self.B)**self.C
    

    Let's set a random input:

    >>> x = torch.rand(1, requires_grad=True)
    

    You can define a hook that will be called when the backward function is called on that module. To do so, you can directly tap into the properties of the module passed as an argument:

    z = 3  # multiplicative factor used in the outputs below

    def hook(module, grad_input, grad_output):
        module.C.grad *= z
    

    Then you attach the hook on an instance of your module with register_full_backward_hook:

    >>> torch.manual_seed(0)
    >>> F = CustomFunction()
    >>> F.register_full_backward_hook(hook)
    

    The gradients are then:

    >>> F(x).backward()
    >>> F.A.grad, F.B.grad, F.C.grad
    (tensor([0.0138]), tensor([0.1044]), tensor([-0.5368]))
    

    For comparison, the same inference without the hook gives:

    >>> torch.manual_seed(0)
    >>> F = CustomFunction()
    >>> F(x).backward()
    >>> F.A.grad, F.B.grad, F.C.grad
    (tensor([0.0138]), tensor([0.1044]), tensor([-0.1789]))
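    Alternatively, you can attach a hook to the parameter itself with Tensor.register_hook, which rewrites the gradient before it is accumulated into C.grad. A minimal sketch (using standalone toy parameters, not the module above):

```python
import torch
import torch.nn as nn

z = 3.0  # multiplicative factor

# Toy standalone parameters, purely for illustration.
A = nn.Parameter(torch.tensor(2.0))
B = nn.Parameter(torch.tensor(1.0))
C = nn.Parameter(torch.tensor(3.0))

# The hook receives the raw gradient and returns the modified one;
# only C's gradient is affected.
C.register_hook(lambda grad: grad * z)

x = torch.tensor(4.0)
loss = (x * A + B) ** C
loss.backward()
# A.grad and B.grad hold the unscaled gradients; C.grad is z times
# what autograd computed.
```

    This avoids touching .grad directly and works even when gradients are accumulated over several backward passes.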