I have a custom forward implementation for a PyTorch loss. Training works well, and I've checked that loss.grad_fn is not None.

I'm trying to understand two things:

How can this function be differentiable when there is an if-else statement on the path from input to output?

Does the path from gt (the ground-truth input) to loss (the output) need to be differentiable, or only the path from pred (the prediction input)?
Here is the source code:
class FocalLoss(nn.Module):
    def __init__(self):
        super(FocalLoss, self).__init__()

    def forward(self, pred, gt):
        pos_inds = gt.eq(1).float()
        neg_inds = gt.lt(1).float()
        neg_weights = torch.pow(1 - gt, 4)

        pos_loss = torch.log(pred) * torch.pow(1 - pred, 2) * pos_inds
        neg_loss = torch.log(1 - pred) * torch.pow(pred, 2) * neg_weights * neg_inds

        num_pos = pos_inds.float().sum()
        pos_loss_s = pos_loss.sum()
        neg_loss_s = neg_loss.sum()

        if num_pos == 0:
            loss = -neg_loss_s
        else:
            loss = -(pos_loss_s + neg_loss_s) / num_pos
        return loss
The if statement is not part of the computational graph. It belongs to the code that builds the graph dynamically (i.e., the forward function), but it is not itself part of the graph. The principle to follow is to ask yourself whether you can backtrack from the output to the leaves of the graph (tensors that have no parents in the graph, i.e., inputs and parameters) by following the grad_fn callback of each node, backpropagating through the graph. You can only do so if every operator along the way is differentiable; in programming terms, each one must implement a backward operation (exposed via grad_fn).
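A minimal sketch (not the FocalLoss itself) makes this concrete: Python control flow runs eagerly while the graph is being built, so autograd only ever records the branch that actually executed.

```python
import torch

x = torch.tensor(2.0, requires_grad=True)

if x.item() > 0:   # a plain Python comparison, not a graph node
    y = x * 3      # only this operation is recorded in the graph
else:
    y = x * 5      # this branch never enters the graph

y.backward()
print(x.grad)      # tensor(3.), the derivative of the branch that ran
```

Each forward call rebuilds the graph, so a different branch on the next call simply produces a different graph.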
In your example, whether num_pos is equal to 0 or not, the resulting loss tensor depends either on neg_loss_s alone or on both pos_loss_s and neg_loss_s. In either case, the resulting loss tensor remains attached to the input pred:
[Figures: the computational graph of loss, attached to pred either through the "neg_loss_s" node alone, or through both the "pos_loss_s" and "neg_loss_s" nodes.]

In your setup, either way, the operation is differentiable.
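You can check this directly. The sketch below runs the FocalLoss from the question through both branches, with hand-picked pred values in (0, 1) so the logarithms stay finite, and confirms that loss is attached to the graph and that gradients reach pred in either case:

```python
import torch
import torch.nn as nn

class FocalLoss(nn.Module):
    """The loss from the question, unchanged."""
    def forward(self, pred, gt):
        pos_inds = gt.eq(1).float()
        neg_inds = gt.lt(1).float()
        neg_weights = torch.pow(1 - gt, 4)
        pos_loss = torch.log(pred) * torch.pow(1 - pred, 2) * pos_inds
        neg_loss = torch.log(1 - pred) * torch.pow(pred, 2) * neg_weights * neg_inds
        num_pos = pos_inds.float().sum()
        if num_pos == 0:
            loss = -neg_loss.sum()
        else:
            loss = -(pos_loss.sum() + neg_loss.sum()) / num_pos
        return loss

criterion = FocalLoss()
for gt in (torch.tensor([1.0, 0.0, 1.0, 0.0]),  # num_pos > 0 branch
           torch.zeros(4)):                      # num_pos == 0 branch
    pred = torch.tensor([0.2, 0.8, 0.5, 0.9], requires_grad=True)
    loss = criterion(pred, gt)
    print(loss.grad_fn)   # a Backward node in both branches, never None
    loss.backward()
    print(pred.grad)      # gradients flow back to pred either way
```

The grad_fn printed differs between branches (the graphs are different), but it is non-None in both, which is exactly what you observed.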
As gt is a ground-truth tensor, it doesn't require gradients, and the path from it to the final loss doesn't need to be differentiable. This is the case in your example, where both pos_inds and neg_inds are non-differentiable because they are produced by boolean comparison operators.
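A quick way to convince yourself, using hypothetical tensor values but the same comparison ops as the question:

```python
import torch

gt = torch.tensor([1.0, 0.0, 0.0])  # ground truth: a plain data tensor
pos_inds = gt.eq(1).float()         # boolean comparison, then cast to float

print(gt.requires_grad)   # False: no gradient is tracked for gt
print(pos_inds.grad_fn)   # None: eq() produces a detached result
```

Autograd treats pos_inds and neg_inds as constants, so they simply mask and weight the terms of the loss without contributing gradient paths of their own.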