Tags: pytorch, loss-function, cross-entropy

'RuntimeError: Expected object of scalar type Long but got scalar' for torch.nn.CrossEntropyLoss()


I'm using this loss function for xlm-roberta-large-longformer, and it gives me the error in the title:

    import torch

    loss_func = torch.nn.CrossEntropyLoss()
    output = torch.softmax(logits.view(-1, num_labels), dim=0).float()
    target = b_labels.type_as(logits).view(-1, num_labels)  # one-hot labels cast to float
    loss = loss_func(output, target)
    train_loss_set.append(loss.item())
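For reference, torch.nn.CrossEntropyLoss already applies log_softmax to its input internally, so it expects raw, unnormalized logits of shape (N, C). Calling torch.softmax first (and over dim=0, the batch dimension, rather than the class dimension) changes the loss. Below is a minimal sketch with random tensors standing in for the model's logits and labels; the sizes and seed are arbitrary assumptions, not taken from the question.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch_size, num_labels = 4, 3  # arbitrary sizes for illustration

logits = torch.randn(batch_size, num_labels)          # raw model outputs, float32, shape (N, C)
target = torch.randint(0, num_labels, (batch_size,))  # class indices, int64, shape (N,)

loss_func = torch.nn.CrossEntropyLoss()
ce = loss_func(logits, target)

# The same value by hand: log_softmax over the class dimension, then negative log-likelihood.
manual = F.nll_loss(F.log_softmax(logits, dim=1), target)

# Passing probabilities instead of logits applies softmax twice and gives a different loss.
double_softmax = loss_func(torch.softmax(logits, dim=1), target)

print(ce.item(), manual.item(), double_softmax.item())
```

In short: pass the logits straight to the loss and let it normalize them.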

When I try

    b_labels.type_as(logits).view(-1, num_labels).long()

it tells me:

    RuntimeError: multi-target not supported at /pytorch/aten/src/THCUNN/generic/ClassNLLCriterion.cu:15
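Both errors come from the target's dtype and shape. In the PyTorch version used here, the target had to be an integer tensor of class indices with shape (N,), so a float target failed the dtype check and a long target of shape (N, C) failed the shape check. Assuming PyTorch 1.10 or later, CrossEntropyLoss also accepts a float target of shape (N, C) holding class probabilities, so a one-hot float target gives the same loss as the equivalent index target, while a long (N, C) target is still rejected (the exact message depends on the version). A minimal sketch with arbitrary random data:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch_size, num_labels = 4, 3  # arbitrary sizes for illustration

logits = torch.randn(batch_size, num_labels)
class_idx = torch.randint(0, num_labels, (batch_size,))         # shape (N,), int64
one_hot = F.one_hot(class_idx, num_classes=num_labels).float()  # shape (N, C), float32

loss_func = torch.nn.CrossEntropyLoss()
loss_from_indices = loss_func(logits, class_idx)  # class-index target
loss_from_probs = loss_func(logits, one_hot)      # probability target (assumes PyTorch >= 1.10)

# An integer target of shape (N, C) is not a valid target; the error text varies by version.
try:
    loss_func(logits, one_hot.long())
    rejected = False
except RuntimeError:
    rejected = True

print(loss_from_indices.item(), loss_from_probs.item(), rejected)
```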

What should I do?


Solution

  • Your target tensor should contain integer class indices (dtype torch.long, shape (N,)), not a one-hot or multi-hot encoding of the classes. Passing a 2-D target is what triggers the "multi-target not supported" error.

    You can recover the class indices from a one-hot encoding with argmax over the class dimension:

    >>> b_labels.argmax(1)
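Putting this together, here is a minimal sketch of the question's loss computation with the fix applied: raw logits go straight into the loss, and the one-hot labels are converted to class indices with argmax. The concrete logits and b_labels below are made-up stand-ins, assuming both have shape (N, num_labels). Note that argmax keeps only one class per row, so it is only appropriate when each sample has exactly one correct label; genuinely multi-label targets call for a different loss, such as torch.nn.BCEWithLogitsLoss.

```python
import torch

torch.manual_seed(0)
batch_size, num_labels = 4, 3  # arbitrary sizes for illustration

# Hypothetical stand-ins for the question's tensors (assumed shapes):
logits = torch.randn(batch_size, num_labels)  # raw model outputs, shape (N, C)
b_labels = torch.tensor([[0, 1, 0],
                         [1, 0, 0],
                         [0, 0, 1],
                         [0, 1, 0]])          # one-hot labels, shape (N, C)

loss_func = torch.nn.CrossEntropyLoss()
output = logits.view(-1, num_labels)                  # no softmax: the loss applies log_softmax itself
target = b_labels.view(-1, num_labels).argmax(dim=1)  # class indices, shape (N,), int64
loss = loss_func(output, target)

train_loss_set = []
train_loss_set.append(loss.item())
print(target.tolist(), loss.item())
```

Here target becomes [1, 0, 2, 1], a 1-D int64 tensor, which is the form CrossEntropyLoss expects.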