Tags: python, pytorch, loss-function

The loss can't backpropagate to the model's parameters with my customized loss function


I designed a customized loss:

import torch


class CustomIndicesEdgeAccuracyLoss(torch.nn.Module):
    def __init__(self, num_classes: int, selected_indices: list):
        super(CustomIndicesEdgeAccuracyLoss, self).__init__()
        self.num_classes = num_classes
        self.selected_indices = selected_indices

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        batch_size, num_classes, feature_size = input.shape
        # Keep only the selected feature positions
        selected_input = input[:, :, self.selected_indices]
        selected_target = target[:, self.selected_indices]
        # Hard class predictions per position
        selected_preds = torch.argmax(selected_input, dim=1)
        # Fraction of correct predictions (edge accuracy)
        edge_acc = torch.eq(selected_preds, selected_target).sum() / torch.numel(selected_preds)
        loss = 1 - edge_acc
        loss.requires_grad = True

        return loss

But the loss won't backpropagate to the model's parameters; in other words, the gradients of the model's parameters are always 0, so the parameters can't be updated. What are the possible reasons, and how should I revise the code?

Here is some information about the local variables in forward():

input.shape: torch.Size([64, 3, 5])
target.shape: torch.Size([64, 5])
selected_input.shape: torch.Size([64, 3, 2])
selected_target.shape: torch.Size([64, 2])
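
For reference, the detachment shows up when inspecting the loss directly (model, criterion, x, and y below are placeholders for my actual objects):

output = model(x)                      # logits, shape (64, 3, 5)
loss = criterion(output, y)
print(loss.grad_fn)                    # None: the loss is detached from the graph
loss.backward()                        # runs only because requires_grad was forced to True
for name, param in model.named_parameters():
    print(name, param.grad)            # None, or all zeros if zero_grad() was called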

P.S. I asked the same question here, so I will copy the answer from this post to there, and vice versa.


Solution

  • You can't use accuracy as a loss function, because it is non-differentiable: no gradient propagates through torch.argmax, so the computation graph is cut at that point. Setting loss.requires_grad = True does not repair this; it only turns the detached result into a new leaf tensor, and backward() still never reaches the model's parameters. You need to compute the loss with differentiable operations.
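
    As one possible revision (a sketch, not the only option): replace the argmax-based accuracy with torch.nn.functional.cross_entropy restricted to the same selected positions. cross_entropy accepts (N, C, d1, ...) logits with (N, d1, ...) targets, so it fits the shapes above directly; the class name and the example indices [1, 3] below are illustrative.

import torch
import torch.nn.functional as F


class CustomIndicesCrossEntropyLoss(torch.nn.Module):
    def __init__(self, num_classes: int, selected_indices: list):
        super().__init__()
        self.num_classes = num_classes
        self.selected_indices = selected_indices

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        # Restrict logits and targets to the selected positions,
        # exactly as in the original loss.
        selected_input = input[:, :, self.selected_indices]   # (batch, num_classes, k)
        selected_target = target[:, self.selected_indices]    # (batch, k)
        # cross_entropy keeps the computation graph intact, unlike argmax.
        return F.cross_entropy(selected_input, selected_target.long())


# Usage with the shapes from the question:
criterion = CustomIndicesCrossEntropyLoss(num_classes=3, selected_indices=[1, 3])
input = torch.randn(64, 3, 5, requires_grad=True)
target = torch.randint(0, 3, (64, 5))
loss = criterion(input, target)
loss.backward()                        # gradients now flow back to `input`
print(input.grad.abs().sum() > 0)      # tensor(True)

    You can still compute 1 - edge_acc with argmax as a monitoring metric; just make sure the tensor you call backward() on comes from the differentiable path.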