deep-learning · pytorch · cross-entropy

In pytorch, how to use the weight parameter in F.cross_entropy()?


I'm trying to write some code like below:

x = Variable(torch.Tensor([[1.0,2.0,3.0]]))
y = Variable(torch.LongTensor([1]))
w = torch.Tensor([1.0,1.0,1.0])
F.cross_entropy(x,y,w)
w = torch.Tensor([1.0,10.0,1.0])
F.cross_entropy(x,y,w)

However, the output of the cross-entropy loss is always 1.4076 no matter what w is. What does the weight parameter of F.cross_entropy() actually do, and how do I use it correctly?
I'm using PyTorch 0.3.


Solution

  • The weight parameter rescales each input's loss according to its target class, and the default reduction then divides by the sum of those weights rather than by the number of inputs. With a single input (or with all inputs sharing the same target class), the class weight appears in both the numerator and the denominator and cancels out, so weight has no effect on the loss — which is exactly what you observed.

    See the difference however with 2 inputs of different target classes:

    import torch
    import torch.nn.functional as F
    from torch.autograd import Variable
    
    x = Variable(torch.Tensor([[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]]))
    y = Variable(torch.LongTensor([1, 2]))  # two samples with different target classes
    
    # Uniform weights: a plain mean of the per-sample losses
    w = torch.Tensor([1.0, 1.0, 1.0])
    res = F.cross_entropy(x, y, w)
    # 0.9076 = (1.4076 + 0.4076) / 2
    
    # Up-weight class 1: its sample now dominates the weighted mean
    w = torch.Tensor([1.0, 10.0, 1.0])
    res = F.cross_entropy(x, y, w)
    # 1.3167 = (10 * 1.4076 + 1 * 0.4076) / (10 + 1)
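
    Under the hood, the default reduction is a weighted mean, not a plain mean: each sample's loss is scaled by the weight of its target class, and the total is divided by the sum of those weights. A minimal sketch that reproduces 1.3167 by hand (written against a recent PyTorch version, where the Variable wrapper is no longer needed):

```python
import torch
import torch.nn.functional as F

# Same logits, targets, and class weights as the example above
x = torch.tensor([[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]])
y = torch.tensor([1, 2])
w = torch.tensor([1.0, 10.0, 1.0])

# Per-sample negative log-likelihood: -log_softmax(x)[i, y[i]]
per_sample = -F.log_softmax(x, dim=1)[torch.arange(len(y)), y]

# Weighted mean: scale each sample's loss by the weight of its target
# class, then divide by the sum of the weights actually used
manual = (w[y] * per_sample).sum() / w[y].sum()

print(round(manual.item(), 4))                    # 1.3167
print(round(F.cross_entropy(x, y, w).item(), 4))  # 1.3167
```

    This is why the divisor is 11 (= 10 + 1) rather than 2 in the weighted case: passing a large weight for a class both amplifies its samples' losses and shifts the averaging toward them.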