Manually computing cross-entropy loss in PyTorch

I am trying to compute the cross-entropy loss manually in PyTorch for an encoder-decoder model.

I used the code posted here to compute it: Cross Entropy in PyTorch

I updated the code to skip padded tokens (label -100). The final code is:

import torch

class compute_crossentropyloss_manual:
    """
    y0 is the matrix of logits with shape (batch_size, C)
    x has shape (batch_size,), whose entries are integers from 0 to C-1
    (or ignore_index for padded tokens)
    """
    def __init__(self, ignore_index=-100) -> None:
        self.ignore_index = ignore_index

    def __call__(self, y0, x):
        loss = 0.
        n_batch, n_class = y0.shape
        for y1, x1 in zip(y0, x):
            class_index = int(x1.item())
            if class_index == self.ignore_index:  # <------ I added this if-statement
                continue
            loss = loss + torch.log(torch.exp(y1[class_index])/(torch.exp(y1).sum()))
        loss = - loss/n_batch
        return loss

To verify that it works, I tested it on a text-generation task and computed the loss both with the torch.nn implementation and with this code.

The loss values are not identical:

Using nn.CrossEntropyLoss:

[screenshot of the loss values]

Using the code from the link above:

[screenshot of the loss values]

Am I missing something?

I tried to find the source code of nn.CrossEntropyLoss but couldn't. At this link (nn/, line 2955), the function calls another cross-entropy function, torch._C._nn.cross_entropy_loss, which I can't find in the repo.
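torch._C._nn.cross_entropy_loss is a binding to PyTorch's compiled C++ (ATen) code, which is why it has no Python source in the repo. Its documented behaviour is that cross-entropy on logits equals log_softmax followed by the negative log-likelihood loss, and that the default mean reduction averages only over targets that are not ignore_index. A minimal check of that equivalence, using random inputs shaped like the demo in the question:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Random inputs shaped like the demo: 100 rows, 50 classes.
logits = torch.randn(100, 50, dtype=torch.float64)
target = torch.randint(0, 50, (100,))
target[40:] = -100  # the last 60 targets are ignored

# Built-in cross-entropy (ignore_index defaults to -100).
ce = F.cross_entropy(logits, target)

# Documented equivalent: log-softmax followed by negative log-likelihood.
nll = F.nll_loss(F.log_softmax(logits, dim=1), target)

# Same thing by hand: average -log p(target) over the non-ignored rows only.
valid = target != -100
log_probs = F.log_softmax(logits, dim=1)[valid]                      # (40, 50)
picked = log_probs[torch.arange(log_probs.size(0)), target[valid]]  # (40,)
by_hand = -picked.mean()

print(ce.item(), nll.item(), by_hand.item())
```

All three values should agree up to floating-point rounding; the key detail is that the mean is taken over the 40 valid rows, not all 100.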


I noticed that the differences appear only when there are -100 tokens in the gold labels.
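In a text-generation setup the -100 labels usually come from padding: the pad token id in the target sequences is replaced by -100 before computing the loss. A minimal sketch, assuming a hypothetical pad id of 0 (the real value depends on the tokenizer):

```python
import torch

PAD_ID = 0  # hypothetical pad token id; use your tokenizer's real value

# Two target sequences of length 5; trailing PAD_ID entries are padding.
labels = torch.tensor([[5, 8, 2, 0, 0],
                       [7, 3, 9, 4, 0]])

# Replace padding with -100, the default ignore_index of nn.CrossEntropyLoss.
labels = labels.masked_fill(labels == PAD_ID, -100)
print(labels.tolist())  # [[5, 8, 2, -100, -100], [7, 3, 9, 4, -100]]
```

This assumes the pad id never appears as a real target token.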

Demo example:

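import torch
from torch import nn

criterion = nn.CrossEntropyLoss()                # built-in, ignore_index=-100 by default
criterion2 = compute_crossentropyloss_manual()  # manual version from above

y = torch.randint(1, 50, (100, 50), dtype=float)
x = torch.randint(1, 50, (100,))

x[40:] = -100
print(criterion(y, x).item())
print(criterion2(y, x).item())
> 25.55788695847976
> 10.223154783391905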

And when there are no -100 labels:

x[40:] = 30  # any valid class index (0 to 49)
print(criterion(y, x).item())
print(criterion2(y, x).item())
> 24.684453267596453
> 24.684453267596453
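A quick check of the numbers above: let N be the number of rows and V the set of rows whose label is not ignore_index. The manual loop skips the rows outside V but still divides by N, whereas nn.CrossEntropyLoss divides by |V|:

```latex
\ell_i = -\log\frac{\exp(y_{i,x_i})}{\sum_{c=0}^{C-1}\exp(y_{i,c})},\qquad
L_{\mathrm{nn}} = \frac{1}{|V|}\sum_{i\in V}\ell_i,\qquad
L_{\mathrm{manual}} = \frac{1}{N}\sum_{i\in V}\ell_i = \frac{|V|}{N}\,L_{\mathrm{nn}}
```

In the first demo N = 100 and |V| = 40 (rows 0 to 39), so L_manual = 0.4 × 25.5579 ≈ 10.2232, which is the second printed value. Without -100 labels |V| = N, so the two losses agree.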


  • I solved the problem by updating the code. I was already skipping the -100 tokens (the if-statement above), but I forgot to reduce the number of rows used for averaging (called n_batch in the code above). After doing that, the loss values are identical to the nn.CrossEntropyLoss values. The final code:

    import torch

    class CrossEntropyLossManual:
        """
        y0 is the matrix of logits with shape (batch_size, C)
        x has shape (batch_size,), whose entries are integers from 0 to C-1
        (or ignore_index for padded tokens)
        """
        def __init__(self, ignore_index=-100) -> None:
            self.ignore_index = ignore_index

        def __call__(self, y0, x):
            loss = 0.
            n_batch, n_class = y0.shape
            for y1, x1 in zip(y0, x):
                class_index = int(x1.item())
                if class_index == self.ignore_index:
                    n_batch -= 1  # ignored tokens don't count toward the mean
                    continue
                loss = loss + torch.log(torch.exp(y1[class_index])/(torch.exp(y1).sum()))
            loss = - loss/n_batch
            return loss
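The loop above calls torch.exp on the raw logits, which overflows for large values, and processes one row at a time. As an alternative (not the author's code), a vectorized sketch of the same masked mean can use torch.log_softmax plus a boolean mask; the class name VectorizedCrossEntropy is hypothetical:

```python
import torch
from torch import nn


class VectorizedCrossEntropy:  # hypothetical name, not a PyTorch class
    """Mean cross-entropy over rows whose target is not ignore_index.

    logits: (N, C) unnormalized scores.
    target: (N,) class indices in [0, C-1], or ignore_index.
    """

    def __init__(self, ignore_index=-100):
        self.ignore_index = ignore_index

    def __call__(self, logits, target):
        valid = target != self.ignore_index  # (N,) boolean mask
        # log_softmax is computed in a numerically stable way, unlike log(exp / sum(exp)).
        log_probs = torch.log_softmax(logits[valid], dim=1)
        picked = log_probs.gather(1, target[valid].unsqueeze(1)).squeeze(1)
        # Divide by the number of valid rows, not N (returns nan if every row is ignored).
        return -picked.mean()


torch.manual_seed(0)
criterion = nn.CrossEntropyLoss()
manual = VectorizedCrossEntropy()

# Demo-style inputs: 100 rows, 50 classes, last 60 targets ignored.
y = torch.randn(100, 50, dtype=torch.float64)
x = torch.randint(0, 50, (100,))
x[40:] = -100
print(manual(y, x).item(), criterion(y, x).item())

# Large logits: exp() in the loop version would overflow, log_softmax does not.
print(manual(y * 1000, x).item(), criterion(y * 1000, x).item())

# Decoder-style (batch, seq_len, vocab) inputs are flattened to (N, C) first.
logits = torch.randn(4, 25, 50, dtype=torch.float64)
labels = torch.randint(0, 50, (4, 25))
labels[:, 20:] = -100  # e.g. padded positions
seq_loss = manual(logits.reshape(-1, 50), labels.reshape(-1))
print(seq_loss.item())
```

Dividing by the number of valid rows is the same fix described in the answer; the boolean mask simply replaces the n_batch -= 1 bookkeeping.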