I'm walking through the autograd part of the PyTorch tutorials. I have two questions:

1. Why do we need to clone grad_output and assign it to grad_input rather than using a simple assignment during backpropagation?
2. What's the purpose of grad_input[input < 0] = 0? Does it mean we don't update the gradient when the input is less than zero?

Here's the code:
import torch

class MyReLU(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        """
        In the forward pass we receive a Tensor containing the input and return
        a Tensor containing the output. ctx is a context object that can be used
        to stash information for backward computation. You can cache arbitrary
        objects for use in the backward pass using the ctx.save_for_backward method.
        """
        ctx.save_for_backward(input)
        return input.clamp(min=0)

    @staticmethod
    def backward(ctx, grad_output):
        """
        In the backward pass we receive a Tensor containing the gradient of the loss
        with respect to the output, and we need to compute the gradient of the loss
        with respect to the input.
        """
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad_input[input < 0] = 0
        return grad_input
Thanks a lot in advance.
Why do we need to clone grad_output and assign it to grad_input rather than using a simple assignment during backpropagation?
tensor.clone() creates a copy of a tensor that imitates the original tensor's requires_grad field. clone is a way to copy a tensor while keeping the copy as part of the computation graph it came from. So grad_input is part of the same computation graph as grad_output, and if we compute the gradient for grad_output, the same will be done for grad_input.

Since we are about to modify the gradient in place (grad_input[input < 0] = 0), we clone grad_output first so that the incoming gradient tensor itself is left untouched.
What's the purpose of grad_input[input < 0] = 0? Does it mean we don't update the gradient when the input is less than zero?
This is done as per the ReLU function's definition. The ReLU function is f(x) = max(0, x). It means if x <= 0 then f(x) = 0, else f(x) = x. In the first case, when x < 0, the derivative of f(x) with respect to x is f'(x) = 0, so we perform grad_input[input < 0] = 0. In the second case it is f'(x) = 1, so we simply pass grad_output through to grad_input (it works like an open gate).
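As a quick check (my own sketch, not part of the original answer; the input values are arbitrary), you can run MyReLU on a small tensor and confirm that the gradient is zeroed exactly where the input was negative:

    import torch

    x = torch.tensor([-2.0, -0.5, 0.0, 1.5, 3.0], requires_grad=True)
    y = MyReLU.apply(x)        # custom autograd Functions are invoked through .apply
    y.sum().backward()         # upstream gradient is all ones

    print(y)        # tensor([0.0000, 0.0000, 0.0000, 1.5000, 3.0000], grad_fn=...)
    print(x.grad)   # tensor([0., 0., 1., 1., 1.]) -- gate closed where x < 0

Note that at exactly x = 0 the mask input < 0 is False, so this implementation lets the gradient pass through there.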