Tags: pytorch, backpropagation, autograd

Why do we need to clone grad_output and assign it to grad_input when defining a ReLU autograd function?


I'm walking through the autograd part of pytorch tutorials. I have two questions:

  1. Why do we need to clone grad_output and assign it to grad_input instead of using a simple assignment during backpropagation?
  2. What's the purpose of grad_input[input < 0] = 0? Does it mean we don't update the gradient when the input is less than zero?

Here's the code:

import torch


class MyReLU(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        """
        In the forward pass we receive a Tensor containing the input and return
        a Tensor containing the output. ctx is a context object that can be used
        to stash information for backward computation. You can cache arbitrary
        objects for use in the backward pass using the ctx.save_for_backward method.
        """
        ctx.save_for_backward(input)
        return input.clamp(min=0)

    @staticmethod
    def backward(ctx, grad_output):
        """
        In the backward pass we receive a Tensor containing the gradient of the loss
        with respect to the output, and we need to compute the gradient of the loss
        with respect to the input.
        """
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad_input[input < 0] = 0
        return grad_input

Link here: https://pytorch.org/tutorials/beginner/pytorch_with_examples.html#pytorch-defining-new-autograd-functions

Thanks a lot in advance.


Solution

  • Why do we need to clone grad_output and assign it to grad_input instead of using a simple assignment during backpropagation?

    tensor.clone() creates a copy of the tensor that preserves the original tensor's requires_grad field. clone is a way to copy a tensor while still keeping the copy as part of the computation graph it came from.

    So grad_input is part of the same computation graph as grad_output, and if we compute the gradient for grad_output, then the same will be done for grad_input.

    Since we modify grad_input in place (zeroing out some of its entries), we clone grad_output first, so grad_output itself is left untouched.
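
    A minimal sketch of this behavior (the tensor names here are made up for illustration, not from the tutorial): the clone stays attached to the graph, so we can edit it in place without touching the original tensor, and gradients still flow back through it.

    import torch

    # clone() keeps the copy in the computation graph, so in-place edits on
    # the clone leave the original tensor untouched while gradients still
    # flow back through it (a sketch, not part of the tutorial code).
    x = torch.randn(5, requires_grad=True)
    y = x.clone()          # y.requires_grad is True; y is connected to x's graph
    y[x < 0] = 0           # in-place edit on the clone only; x is unchanged
    y.sum().backward()     # gradients propagate through the clone back to x
    print(x.grad)          # 1.0 where x >= 0, 0.0 where x < 0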

  • What's the purpose of grad_input[input < 0] = 0? Does it mean we don't update the gradient when the input is less than zero?

    This is done as per the definition of the ReLU function, f(x) = max(0, x): if x <= 0 then f(x) = 0, else f(x) = x. In the first case, when x < 0, the derivative of f(x) with respect to x is f'(x) = 0, so we perform grad_input[input < 0] = 0. In the second case f'(x) = 1, so we simply pass grad_output through to grad_input (it works like an open gate). At exactly x = 0 the derivative is undefined; this code passes the gradient through there, which is a common convention.
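
    As a quick check (a sketch that assumes the MyReLU class from the question is defined in scope; the input values are made up), the custom Function is invoked through its .apply method, and its gradient matches the built-in ReLU for nonzero inputs:

    import torch

    x = torch.tensor([-2.0, -0.5, 0.5, 2.0], requires_grad=True)
    y = MyReLU.apply(x)        # custom autograd Functions are called via .apply
    y.sum().backward()
    print(x.grad)              # tensor([0., 0., 1., 1.]) -- zero where input < 0

    x2 = x.detach().clone().requires_grad_(True)
    torch.relu(x2).sum().backward()
    print(torch.allclose(x.grad, x2.grad))   # True: same gradient as torch.relu here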