Tags: python, pytorch, cross-entropy

pytorch cross-entropy-loss weights not working


I was playing around with some code and it behaved differently than I expected, so I boiled it down to a minimal working example:

import torch

test_act = torch.tensor([[2.,0.]])
test_target = torch.tensor([0])

loss_function_test = torch.nn.CrossEntropyLoss()
loss_test = loss_function_test(test_act, test_target)
print(loss_test)
> tensor(0.1269)

weights=torch.tensor([0.1,0.5])
loss_function_test = torch.nn.CrossEntropyLoss(weight=weights)
loss_test = loss_function_test(test_act, test_target)
print(loss_test)
> tensor(0.1269)

As you can see, the output is the same whether or not weights are present, but I would expect the second output to be 0.0127 (i.e. 0.1 × 0.1269).

Is there some normalization going on that I don't know about, or is it possibly a bug?


Solution

  • In this example, I add a second datum with a different target class, and the effect of weights is visible.

    import torch
    
    test_act = torch.tensor([[2.,1.],[1.,4.]])
    test_target = torch.tensor([0,1])
    
    loss_function_test = torch.nn.CrossEntropyLoss()
    loss_test = loss_function_test(test_act, test_target)
    print(loss_test)
    >>> tensor(0.1809)
    
    
    weights=torch.tensor([0.1,0.5])
    loss_function_test = torch.nn.CrossEntropyLoss(weight=weights)
    loss_test = loss_function_test(test_act, test_target)
    print(loss_test)
    >>> tensor(0.0927)
    

    This effect occurs because, as the documentation puts it, "The losses are averaged across observations for each minibatch. If the weight argument is specified then this is a weighted average", but only across the minibatch. Concretely, the default reduction computes sum_i(weight[target_i] * loss_i) / sum_i(weight[target_i]): the denominator is the sum of the weights of the target classes actually present in the batch. In the question's single-sample example that denominator is just 0.1, so the weight cancels out and the loss is unchanged.
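
    To make the normalization explicit, here is a minimal sketch that reproduces the 0.0927 by hand. It uses reduction='none', which returns the per-sample losses already scaled by the weight of each sample's target class:

    import torch
    
    test_act = torch.tensor([[2.,1.],[1.,4.]])
    test_target = torch.tensor([0,1])
    weights = torch.tensor([0.1,0.5])
    
    # reduction='none' yields weight[target_i] * loss_i for each sample
    loss_function_test = torch.nn.CrossEntropyLoss(weight=weights, reduction='none')
    per_sample = loss_function_test(test_act, test_target)
    
    # the default 'mean' reduction divides by the summed weights of the
    # targets present in the batch, not by the number of samples
    print(per_sample.sum() / weights[test_target].sum())
    >>> tensor(0.0927)

    In the single-sample example from the question, the denominator is just 0.1, so (0.1 * 0.1269) / 0.1 gives back 0.1269.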

    Personally I find this a bit strange and would think it more useful to apply the weights globally (i.e. as if all classes were present in every minibatch). One of the prominent uses of the weight parameter is presumably to give more weight to classes that are under-represented in the dataset, but under this formulation the minority classes receive higher weight only in the minibatches in which they actually appear (which, of course, is a small fraction of minibatches, precisely because they are minority classes).
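
    If the global behaviour is what you want, one workaround (a sketch, not something PyTorch does by default) is to take the per-sample weighted losses via reduction='none' and average over the batch size yourself, so each class weight acts as a fixed scale factor no matter which classes appear in the minibatch:

    import torch
    
    test_act = torch.tensor([[2.,0.]])
    test_target = torch.tensor([0])
    weights = torch.tensor([0.1,0.5])
    
    loss_function_test = torch.nn.CrossEntropyLoss(weight=weights, reduction='none')
    # .mean() divides by the number of samples instead of the summed target weights
    print(loss_function_test(test_act, test_target).mean())
    >>> tensor(0.0127)

    This gives exactly the 0.0127 the question expected.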

    In any case, that is how PyTorch defines this operation.