Tags: tensorflow, pytorch, automatic-differentiation

Is there a PyTorch equivalent of tf.custom_gradient()?


I am new to PyTorch but have a lot of experience with TensorFlow.

I would like to modify the gradient of just a tiny piece of the graph: just the derivative of the activation function of a single layer. This is easy to do in TensorFlow using tf.custom_gradient, which lets you supply a custom gradient for any function.

I would like to do the same thing in PyTorch. I know that you can modify the backward() method, but that seems to require rewriting the derivative for the whole network defined in the forward() method, when I only want to modify the gradient of a tiny piece of the graph. Is there something like tf.custom_gradient() in PyTorch? Thanks!


Solution

  • You can do this in two ways:

    1. Modifying the backward() function:
    As you already noted in your question, PyTorch also lets you provide a custom backward implementation by subclassing torch.autograd.Function. However, in contrast to what you wrote, you do not need to rewrite the backward() of the entire model - only the backward() of the specific layer you want to change.
    Here's a simple and nice tutorial that shows how this can be done.

    For example, here is a custom clip activation that, instead of killing the gradients outside the [0, 1] range, simply passes the gradients through as-is:

    import torch

    class MyClip(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            # forward pass: ordinary clipping to [0, 1]
            return torch.clip(x, 0., 1.)
    
        @staticmethod
        def backward(ctx, grad):
            # backward pass: pass the incoming gradient through unchanged,
            # instead of zeroing it outside the clipped range
            return grad
    

    Now you can use the MyClip layer wherever you like in your model, and you do not need to worry about the overall backward function; a usage sketch follows below.
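
    A minimal usage sketch (the SmallNet module, its layer sizes, and the tensor shapes are made up for illustration); note that a custom Function is invoked through its .apply method:

    import torch
    import torch.nn as nn

    class SmallNet(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(4, 4)

        def forward(self, x):
            # the custom gradient applies only to this activation
            return MyClip.apply(self.fc(x))

    net = SmallNet()
    out = net(torch.randn(2, 4, requires_grad=True))
    out.sum().backward()  # gradients flow through MyClip.backward unchanged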


    2. Using a backward hook:
    Hooks allow you to attach functions to different layers (i.e., sub-nn.Modules) of your network. You can call register_full_backward_hook on your layer, and the hook function can modify the gradients. As the PyTorch documentation notes:

    The hook should not modify its arguments, but it can optionally return a new gradient with respect to the input that will be used in place of grad_input in subsequent computations.
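
    For example, here is a hedged sketch of the hook approach (the choice of an nn.ReLU layer and the 0.5 scaling factor are arbitrary, just to show the mechanics):

    import torch
    import torch.nn as nn

    layer = nn.ReLU()

    def scale_grad(module, grad_input, grad_output):
        # return a tuple matching grad_input; here the incoming gradient is simply halved
        return tuple(g * 0.5 if g is not None else None for g in grad_input)

    handle = layer.register_full_backward_hook(scale_grad)

    x = torch.randn(3, requires_grad=True)
    layer(x).sum().backward()
    print(x.grad)    # gradient w.r.t. x, scaled by 0.5 by the hook
    handle.remove()  # detach the hook when it is no longer needed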