pytorch, huggingface-transformers

torch CrossEntropyLoss calculation difference between 2D input and 3D input


I am running a test on torch.nn.CrossEntropyLoss, using the example shown on the official documentation page.

import torch
import torch.nn as nn

loss = nn.CrossEntropyLoss()
input = torch.randn(3, 5, requires_grad=False)
target = torch.randn(3, 5).softmax(dim=1)
output = loss(input, target)

The output is 2.05. In the example, both the input and the target are 2D tensors. Since in most NLP cases the input is a 3D tensor, and correspondingly the target should be 3D as well, I wrote a couple of lines of test code and found a weird issue.

input = torch.stack([input])    # shape (1, 3, 5)
target = torch.stack([target])  # shape (1, 3, 5)
output = loss(input, target)

The output is 0.9492. This result really confuses me: except for the dimensions, the numbers inside the tensors are exactly the same. Does anyone know the reason for the difference?

The reason I am testing this method is that I am working on a project with transformers.BartForConditionalGeneration. The loss is returned in the model output and is always of shape (1,), which is confusing: if my batch size is greater than 1, I would expect to get one loss per batch element instead of a single number. I took a look at the code, and it simply uses nn.CrossEntropyLoss(), so I suspect the issue lies in nn.CrossEntropyLoss(), but I am stuck at this point.
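
For what it's worth, here is a minimal sketch of what I mean (the shapes and names are made up for illustration, assuming the usual pattern of flattening the logits to 2D before the loss): the default reduction="mean" collapses everything to a single number, while reduction="none" keeps one loss per token, from which a per-sample loss can be recovered.

logits = torch.randn(2, 4, 10)          # hypothetical (batch, seq_len, vocab) logits
labels = torch.randint(0, 10, (2, 4))   # hypothetical class-index targets

# default reduction="mean" -> one scalar for the whole batch
mean_loss = nn.CrossEntropyLoss()(logits.view(-1, 10), labels.view(-1))

# reduction="none" -> one loss per token, then average per sequence
per_token = nn.CrossEntropyLoss(reduction="none")(logits.view(-1, 10), labels.view(-1))
per_sample = per_token.view(2, 4).mean(dim=1)   # shape (2,), one loss per batch element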


Solution

  • In the second case, you are adding an extra dimension, which means the softmax over the logits tensor (input) is no longer applied over the same dimension: nn.CrossEntropyLoss always treats dim=1 as the class dimension, so for a (1, 3, 5) input the softmax runs over the axis of size 3 instead of the axis of size 5.

    Here we compute the two quantities separately:

    >>> import torch
    >>> import torch.nn as nn
    >>> import torch.nn.functional as F
    >>> loss = nn.CrossEntropyLoss()
    >>> input = torch.randn(3, 5, requires_grad=False)
    >>> target = torch.randn(3, 5).softmax(dim=1)
    

    First, you have loss(input, target), which is identical to:

    >>> # log-softmax over dim 1, which here is the class dimension of size 5
    >>> o = -target * F.log_softmax(input, 1)
    >>> o.sum(1).mean()
    

    And your second scenario, loss(input[None], target[None]), is identical to:

    >>> # log-softmax still over dim 1, but for a (1, 3, 5) tensor that axis has size 3
    >>> o = -target[None] * F.log_softmax(input[None], 1)
    >>> o.sum(1).mean()
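
    For the NLP-style case from the question, this means the class axis has to end up in dim 1 before calling the loss. A quick sketch reusing the loss object from above (the (2, 3, 5) shape is just an illustrative (batch, seq_len, num_classes) example): either flatten logits and targets to 2D, or permute the class axis into position 1; both give the same value.

    >>> logits = torch.randn(2, 3, 5)                  # hypothetical (batch, seq_len, classes)
    >>> probs = torch.randn(2, 3, 5).softmax(dim=-1)   # probability targets over the classes
    >>> # option 1: flatten so the classes sit in dim 1 of a 2D tensor
    >>> loss(logits.reshape(-1, 5), probs.reshape(-1, 5))
    >>> # option 2: move the class axis to dim 1, i.e. (batch, classes, seq_len)
    >>> loss(logits.permute(0, 2, 1), probs.permute(0, 2, 1))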