
Why does my cross entropy loss function not converge?


I tried to write a cross entropy loss function myself. My loss function gives the same loss value as the official one, but when I use it in my code instead of the official cross entropy loss function, the training does not converge. When I use the official cross entropy loss function, it converges. Here is my code, please give me some suggestions. Thanks very much. The input 'out' is a tensor of shape (B * C) and 'label' contains the class indices (shape 1 * B).

import torch
import torch.nn as nn
from torch.autograd import Variable


class MylossFunc(nn.Module):
    def __init__(self):
        super(MylossFunc, self).__init__()

    def forward(self, out, label):
        out = torch.nn.functional.softmax(out, dim=1)
        n = len(label)
        loss = torch.FloatTensor([0])
        loss = Variable(loss, requires_grad=True)
        tmp = torch.log(out)
        #print(out)
        for i in range(n):
            # clamp each log-probability at -100 and average over the batch
            loss = loss - torch.max(tmp[i][label[i]], torch.scalar_tensor(-100)) / n
        loss = torch.sum(loss)
        return loss
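
For example, on random inputs the two losses agree (a quick sanity check, not my real model; B and C are arbitrary here, and I pass 'label' as a 1-D tensor of B class indices):

import torch
import torch.nn.functional as F

B, C = 4, 10                        # arbitrary batch size and number of classes
out = torch.randn(B, C)             # random logits of shape (B, C)
label = torch.randint(0, C, (B,))   # random class indices

criterion = MylossFunc()            # the class defined above
print(criterion(out, label))        # my loss
print(F.cross_entropy(out, label))  # official loss, prints (almost) the same value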

Solution

  • Instead of using torch.softmax and torch.log, you should use torch.log_softmax; otherwise your training will become unstable, with nan values everywhere.

    This happens because when you take the softmax of your logits using the following line:

    out = torch.nn.functional.softmax(out, dim=1)
    

    you might get an exact zero in one of the components of out, and applying torch.log to it yields -inf (since log(0) is undefined), which quickly turns into nan values during training. That is why torch (and other common libraries) provide a single stable operation, log_softmax, to avoid the numerical instability of computing torch.softmax and torch.log separately.
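
    As a sketch of how the forward pass above could look with this change (keeping the rest of the original structure, so the clamp at -100 and the averaging over the batch are unchanged):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F


    class MylossFunc(nn.Module):
        def __init__(self):
            super(MylossFunc, self).__init__()

        def forward(self, out, label):
            # one numerically stable op instead of softmax followed by log
            tmp = F.log_softmax(out, dim=1)
            n = len(label)
            loss = torch.zeros(1)
            for i in range(n):
                loss = loss - torch.max(tmp[i][label[i]], torch.scalar_tensor(-100)) / n
            return torch.sum(loss)

    Equivalently, the whole loop can be replaced by F.nll_loss(F.log_softmax(out, dim=1), label), which is effectively what F.cross_entropy computes.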