Tags: python, pytorch, loss-function, cross-entropy

Cross entropy loss with weight: manual calculation


Hi, just playing around with some code, I got an unexpected result from the weighted cross entropy loss implementation.

import torch
import torch.nn.functional as F

pred = torch.tensor([[8,5,3,2,6,1,6,8,4],[2,5,1,3,4,6,2,2,6],[1,1,5,8,9,2,5,2,8],[2,2,6,4,1,1,7,8,3],[2,2,2,7,1,7,3,4,9]]).float()
label = torch.tensor([[3],[7],[8],[2],[5]], dtype=torch.int64)
weights = torch.tensor([1,1,1,10,1,6,1,1,1], dtype=torch.float32)

With these sample variables, PyTorch's cross entropy loss gives 4.7894:

loss = F.cross_entropy(pred, label.view(-1), weight=weights, reduction='mean')  # target flattened to a 1D tensor of class indices
> 4.7894

I manually implemented the cross entropy loss as below:

one_hot = torch.zeros_like(pred).scatter(1, label.view(-1, 1), 1)  # one-hot encode the target classes
log_prb = F.log_softmax(pred, dim=1)                               # per-class log-probabilities
loss = -(one_hot * log_prb).sum(dim=1).mean()                      # mean negative log-likelihood
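
As a quick sanity check (a minimal sketch, reusing pred, label, one_hot and log_prb from above), the unweighted result can be compared directly against F.cross_entropy:

manual  = -(one_hot * log_prb).sum(dim=1).mean()
builtin = F.cross_entropy(pred, label.view(-1), reduction='mean')  # 1D class-index target
assert torch.isclose(manual, builtin)                              # same unweighted loss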

This implementation gives the same result as PyTorch's cross entropy function when no weight is given. However, with the weight applied:

one_hot = torch.zeros_like(pred).scatter(1, label.view(-1, 1), 1)
log_prb = F.log_softmax(pred, dim=1)
loss = -(one_hot * log_prb * weights).sum(dim=1).sum() / weights.sum()  # normalise by the sum of ALL class weights
> 3.9564

it gives a different loss value from the PyTorch module (4.7894). I can roughly guess that my understanding of the loss's weight argument is the problem here, but I can't find the exact reason for the discrepancy. Can anybody help me with this issue?


Solution

  • I found the problem. It was quite simple... I shouldn't have divided by the sum of all the class weights. Dividing by wt.sum() instead (where wt = one_hot * weights) gives 4.7894.

    >>> wt = one_hot*weights
    >>> loss = -(one_hot * log_prb * weights).sum(dim=1).sum() / wt.sum()
    4.7894
    

    The denominator should only contain the 'related' weights, i.e. the weights of the target classes that actually appear in the batch, not the sum over all classes.
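
    For anyone verifying this, here is a minimal sketch (reusing pred, label and weights from above) of the same convention via reduction='none', which already multiplies each sample's loss by the weight of its target class; the 'mean' reduction then divides by the sum of those target-class weights:

    per_sample = F.cross_entropy(pred, label.view(-1), weight=weights, reduction='none')
    target_w   = weights[label.view(-1)]                 # tensor([10., 1., 1., 1., 6.])
    manual     = per_sample.sum() / target_w.sum()       # ≈ 4.7894
    builtin    = F.cross_entropy(pred, label.view(-1), weight=weights, reduction='mean')
    assert torch.isclose(manual, builtin)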