
Cross-entropy loss and torch weights mismatch


My targets are mostly class 0, and less frequently class 1 or 2. I'm trying to use cross-entropy loss with class weights. The following code:

weights = torch.tensor([1., 10, 10.]).to(device)
lossfn = nn.CrossEntropyLoss(weight=weights) 
pred = model(input1, input2)
target = labelarray.type(torch.LongTensor).to(device) 
loss = lossfn(pred, target)

produces the following error:

RuntimeError: weight tensor should be defined either for all 1 classes or no classes but got weight tensor of shape: [3]

B (batch size) is 128.

pred has shape torch.Size([B, 1])

target has shape torch.Size([B]) and contains:

([0, 0, 0, 0, 0, 2, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 2, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 1, 2, 1, 0, 0, 2, 1, 0, 0, 2, 1, 2, 1, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 2, 2, 2, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 1, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0], device='cuda:0')

Clearly my target has 3 classes, yet the error suggests PyTorch sees just 1 class. Why?
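A minimal way to see where the "1 classes" in the message comes from is to rebuild the question's shapes with random stand-in tensors. This is a sketch under assumptions: the values are made up (only the shapes follow the question), and it runs on the CPU rather than `cuda:0`. It prints the class count PyTorch reads from `pred` and catches the resulting error.

```python
import torch
import torch.nn as nn

B = 128          # batch size from the question
n_classes = 3    # classes present in the target

# Random stand-ins with the question's shapes (values are made up).
pred = torch.randn(B, 1)                      # one logit per sample, as in the question
target = torch.randint(0, n_classes, (B,))    # labels drawn from 0, 1, 2
weights = torch.tensor([1., 10., 10.])        # one weight per intended class

# CrossEntropyLoss takes the class count from pred.shape[1], not from target.
print("classes inferred from pred:", pred.shape[1])
print("labels present in target:", target.unique().tolist())

loss_fn = nn.CrossEntropyLoss(weight=weights)
try:
    loss_fn(pred, target)
    err_msg = None
except RuntimeError as e:
    err_msg = str(e)  # the weight-length check fails before any loss is computed
print(err_msg)
```

The label values in `target` never enter the class-count check, which is why a target containing 0, 1 and 2 does not help.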


Solution

  • Your prediction has the wrong shape. The prediction tensor should have shape (bs, n_classes), where bs is the batch size. Since you have three classes, your prediction should have shape (bs, 3).

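    If your prediction has shape (bs, 1), CrossEntropyLoss takes the size of dimension 1 as the number of classes, so it sees exactly one class and rejects a weight tensor of length 3; that is the error you are getting. The softmax over that single-element class dimension also returns 1. for every row, so no matter what your model does, an output of shape (bs, 1) is interpreted as predicting class 0 with 100% confidence.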

    To predict three classes, your output should have shape (bs, 3), typically by giving the model's final layer three outputs. With the correct shape, the class weights work as expected:

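    import torch
    import torch.nn as nn

    bs = 32
    n_classes = 3
    preds = torch.randn(bs, n_classes)                    # logits, shape (bs, 3)
    targs = torch.randint(0, high=n_classes, size=(bs,))  # class indices, shape (bs,)
    weights = torch.tensor([1., 10., 10.])                # one weight per class
    loss_fn = nn.CrossEntropyLoss(weight=weights)
    loss_fn(preds, targs)
    # e.g. tensor(1.5492); the exact value varies because preds and targs are random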
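The sketch below first checks the softmax claim for a one-column output, then shows one way the fix could look inside a model. The question does not show `model`, so `TwoInputNet`, its input sizes (8 and 4) and its hidden width are hypothetical stand-ins; the only part that matters is that the final `nn.Linear` emits one logit per class.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bs, n_classes = 128, 3
target = torch.randint(0, n_classes, (bs,))

# 1) A one-column output: softmax over dim=1 is identically 1.
one_col = torch.randn(bs, 1)
probs_one_col = torch.softmax(one_col, dim=1)
print(probs_one_col.unique())  # every entry is 1.0

# 2) Hypothetical stand-in for model(input1, input2); sizes are assumptions.
class TwoInputNet(nn.Module):
    def __init__(self, in1=8, in2=4, hidden=16, n_classes=3):
        super().__init__()
        self.body = nn.Sequential(nn.Linear(in1 + in2, hidden), nn.ReLU())
        # The final layer emits one logit per class (3 here, not 1).
        self.head = nn.Linear(hidden, n_classes)

    def forward(self, x1, x2):
        # Concatenate the two inputs along the feature dimension.
        return self.head(self.body(torch.cat([x1, x2], dim=1)))

model = TwoInputNet(n_classes=n_classes)
input1, input2 = torch.randn(bs, 8), torch.randn(bs, 4)
pred = model(input1, input2)  # shape (bs, 3): raw logits, no softmax applied

weights = torch.tensor([1., 10., 10.])
loss_fn = nn.CrossEntropyLoss(weight=weights)  # applies log-softmax internally
loss = loss_fn(pred, target)
loss.backward()

predicted_class = pred.argmax(dim=1)  # shape (bs,), class indices in {0, 1, 2}
print(pred.shape, loss.item())
```

The model returns raw logits because `nn.CrossEntropyLoss` applies log-softmax itself; adding a softmax layer before it would apply the normalization twice.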