Tags: python, math, pytorch, cross-entropy

pos_weight in binary cross entropy calculation


When we deal with imbalanced training data (there are more negative samples than positive samples), the pos_weight parameter is usually used. The expectation is that the model will incur a higher loss when a positive sample gets the wrong label than when a negative sample does. However, when I use the binary_cross_entropy_with_logits function, I found:

import torch

bce = torch.nn.functional.binary_cross_entropy_with_logits

pos_weight = torch.FloatTensor([5])
preds_pos_wrong = torch.FloatTensor([0.5, 1.5])
label_pos = torch.FloatTensor([1, 0])
loss_pos_wrong = bce(preds_pos_wrong, label_pos, pos_weight=pos_weight)

preds_neg_wrong = torch.FloatTensor([1.5, 0.5])
label_neg = torch.FloatTensor([0, 1])
loss_neg_wrong = bce(preds_neg_wrong, label_neg, pos_weight=pos_weight)

However:

>>> loss_pos_wrong
tensor(2.0359)

>>> loss_neg_wrong
tensor(2.0359)

The losses for the mislabeled positive sample and the mislabeled negative sample are the same, so how does pos_weight actually work in the loss calculation for imbalanced data?


Solution

  • TL;DR: both losses are identical because you are computing the same quantity: the two inputs are identical, with the batch elements and their labels simply switched.


    Why are you getting the same loss?

    I think you got confused in the usage of F.binary_cross_entropy_with_logits (you can find a more detailed documentation page for nn.BCEWithLogitsLoss). In your case, your input shape (i.e. the output of your model) is one-dimensional, which means there is only a single logit x per batch element, not two.

    In your example you have

    preds_pos_wrong = torch.FloatTensor([0.5, 1.5])
    label_pos = torch.FloatTensor([1, 0])
    

    This means your batch size is 2, and since by default the function averages the losses over the batch elements, you end up with the same result for BCE(preds_pos_wrong, label_pos) and BCE(preds_neg_wrong, label_neg). The two elements of your batch are just switched.

    You can verify this very easily by not averaging the loss over the batch elements, using the reduction='none' option:

    >>> F.binary_cross_entropy_with_logits(preds_pos_wrong, label_pos, 
           pos_weight=pos_weight, reduction='none')
    tensor([2.3704, 1.7014])
    
    >>> F.binary_cross_entropy_with_logits(preds_neg_wrong, label_neg,
           pos_weight=pos_weight, reduction='none')
    tensor([1.7014, 2.3704])
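
    Averaging either of these per-element loss vectors recovers the scalar reported in the question, since both contain the same two values in a different order. A quick check, reusing the tensors defined above (with F standing for torch.nn.functional):

    >>> F.binary_cross_entropy_with_logits(preds_pos_wrong, label_pos,
           pos_weight=pos_weight, reduction='none').mean()
    tensor(2.0359)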
    

    Looking into F.binary_cross_entropy_with_logits:

    That being said, the formula for the binary cross-entropy is:

    bce = -[y*log(sigmoid(x)) + (1-y)*log(1 - sigmoid(x))]
    

    Where y (respectively sigmoid(x)) corresponds to the positive class associated with that logit, and 1 - y (resp. 1 - sigmoid(x)) to the negative class.
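
    As a quick sanity check of this formula on the question's data, the manual computation matches the builtin call when no weighting is applied (reusing preds_pos_wrong and label_pos from above):

    >>> z = torch.sigmoid(preds_pos_wrong)
    >>> -(label_pos*torch.log(z) + (1-label_pos)*torch.log(1-z))
    tensor([0.4741, 1.7014])

    >>> F.binary_cross_entropy_with_logits(preds_pos_wrong, label_pos, reduction='none')
    tensor([0.4741, 1.7014])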

    The documentation could be more precise about the weighting scheme for pos_weight (not to be confused with weight, which weights the different logit outputs). The idea with pos_weight, as you said, is to weight the positive term only, not the whole term:

    bce = -[w_p*y*log(sigmoid(x)) + (1-y)*log(1 - sigmoid(x))]
    

    Where w_p is the weight for the positive term, to compensate for the positive to negative sample imbalance. In practice, this should be w_p = #negative/#positive.

    Therefore:

    >>> w_p = torch.FloatTensor([5])
    >>> preds = torch.FloatTensor([0.5, 1.5])
    >>> label = torch.FloatTensor([1, 0])
    

    With the built-in loss function:

    >>> F.binary_cross_entropy_with_logits(preds, label, pos_weight=w_p, reduction='none')
    tensor([2.3704, 1.7014])
    

    Compared with the manual computation:

    >>> z = torch.sigmoid(preds)
    >>> -(w_p*label*torch.log(z) + (1-label)*torch.log(1-z))
    tensor([2.3704, 1.7014])
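
    As a final illustration of the #negative/#positive rule, here is a minimal sketch of how one could derive pos_weight from a label tensor and pass it to nn.BCEWithLogitsLoss. The labels tensor below is hypothetical (1 positive out of 6 samples), chosen so that the ratio reproduces the w_p = 5 used above:

    >>> labels = torch.FloatTensor([1, 0, 0, 0, 0, 0])  # hypothetical training labels
    >>> num_pos = labels.sum()
    >>> num_neg = labels.numel() - num_pos
    >>> w_p = num_neg / num_pos                          # 5 negatives / 1 positive = 5
    >>> criterion = torch.nn.BCEWithLogitsLoss(pos_weight=w_p)
    >>> criterion(preds, label)                          # preds and label as defined above
    tensor(2.0359)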