Tags: python, deep-learning, pytorch, loss-function, weighted

Multi-channel, 2D mask weights using BCEWithLogitsLoss in Pytorch


I have a set of 256x256 images, each labeled with nine binary 256x256 masks. I am trying to calculate the pos_weight to weight the BCEWithLogitsLoss in PyTorch.

The shape of my masks tensor is torch.Size([1000, 9, 256, 256]), where 1000 is the number of training images, 9 is the number of mask channels (all encoded as 0/1), and 256 is the length of each image side.

To calculate pos_weight, I summed the zeros in each mask channel and divided that number by the sum of all the ones in that channel (following the advice suggested here):

(masks[:,channel,:,:]==0).sum()/masks[:,channel,:,:].sum()
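Instead of evaluating this expression once per channel, the same ratio can be computed for all nine channels at once by reducing over every dimension except the channel dimension. A minimal sketch, assuming `masks` is a float tensor of 0/1 values with shape (N, 9, H, W); the small random tensor below is only a stand-in for the real data:

```python
import torch

# Stand-in for the real masks: N=4 images, 9 channels, 8x8 pixels, values in {0, 1}.
torch.manual_seed(0)
masks = torch.randint(0, 2, (4, 9, 8, 8)).float()

# Count negatives (zeros) and positives (ones) per channel by summing over
# the batch and both spatial dimensions, keeping only the channel dimension.
num_neg = (masks == 0).sum(dim=(0, 2, 3))
num_pos = masks.sum(dim=(0, 2, 3))
pos_weight = num_neg / num_pos  # one value per channel

# The per-channel expression from the question, for comparison.
loop_version = torch.stack(
    [(masks[:, ch, :, :] == 0).sum() / masks[:, ch, :, :].sum() for ch in range(masks.shape[1])]
)

print(pos_weight.shape)  # torch.Size([9])
```

Note that a channel with no positive pixels at all would give a division by zero (an `inf` weight), so it may be worth checking `num_pos` before using the result.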

Calculating the weight for every mask channel gives a tensor of shape torch.Size([9]), which seems intuitive to me, since I want one pos_weight value for each of the nine mask channels. However, when I try to fit my model, I get the following error message:

RuntimeError: The size of tensor a (9) must match the size of tensor b (256) at non-singleton dimension 3
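The error can be reproduced without a model or real data. A small sketch with hypothetical shapes (batch 2, 9 channels, 4x5 pixels, so the width 5 differs from 9); the exact wording of the message may vary between PyTorch versions, but it should report a size mismatch like the one above:

```python
import torch
import torch.nn.functional as F

# Hypothetical toy shapes: (N, C, H, W) = (2, 9, 4, 5).
pred = torch.randn(2, 9, 4, 5)
label = torch.randint(0, 2, (2, 9, 4, 5)).float()
pos_weight = torch.ones(9)  # 1-D: one value per channel

try:
    F.binary_cross_entropy_with_logits(pred, label, pos_weight=pos_weight)
    error_message = None
except RuntimeError as err:
    # The (9,) weight is lined up with the last dimension (width 5), not the channels.
    error_message = str(err)

print(error_message)
```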

This error message is surprising because it suggests that the weights need to match the size of one of the image sides rather than the number of mask channels. What shape should pos_weight be, and how do I specify that it should provide weights for the mask channels instead of the image pixels?


Solution

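  • TL;DR: This is a broadcasting issue which is surprisingly not handled by PyTorch's nn.BCEWithLogitsLoss, i.e. F.binary_cross_entropy_with_logits. pos_weight is broadcast against the target starting from the last dimension, so a 1-D tensor of length 9 gets aligned with the width dimension (256) instead of the channel dimension, hence the mismatch at dimension 3. It might actually be worth opening a GitHub issue linking to this SO thread to notify the developers of this undesirable behaviour.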

    On the documentation page of nn.BCEWithLogitsLoss, it is stated that the provided positive-weights tensor, pos_weight:

    Must be a vector with length equal to the number of classes.

    This is of course what you were expecting (rightly so), since positive weights are the weights given to the positive instances of each class. However, your prediction and target tensors are multi-dimensional, with the class dimension in second position (N, C, H, W) rather than last, and PyTorch does not seem to handle this case properly.
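The alignment problem can be checked on shapes alone with torch.broadcast_shapes, which applies the same right-to-left broadcasting rules without allocating any data. A small sketch using the shapes from the question:

```python
import torch

target_shape = (1000, 9, 256, 256)  # (N, C, H, W) as in the question

# Broadcasting compares shapes from the right, so (9,) is matched against W=256.
try:
    torch.broadcast_shapes((9,), target_shape)
    flat_ok = True
except RuntimeError:
    flat_ok = False

# With two trailing singleton dimensions, 9 lines up with C and the
# size-1 dimensions are stretched over H and W.
shaped = torch.broadcast_shapes((9, 1, 1), target_shape)

print(flat_ok)  # False
print(shaped)   # torch.Size([1000, 9, 256, 256])
```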


    Anyhow, here is a minimal example showing how to bypass this error, along with a manual computation of the binary cross-entropy for reference.

    Here is the setup of the prediction and target tensors, pred and label respectively:

    >>> import torch
    >>> import torch.nn.functional as F
    >>> c, b, h, w = 2, 5, 3, 3
    >>> pred = torch.rand(b, c, h, w)
    >>> label = torch.randint(0, 2, (b, c, h, w)).float()
    

    Now for the definition of the positive weight; notice the trailing singleton dimensions:

    >>> pos_weight = torch.rand(c,1,1) 
    

    In your case, with your existing 1-D tensor of length c, you simply have to add two trailing singleton dimensions for the height and width, for example: pos_weight = pos_weight[:,None,None].
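A few equivalent ways to append those two singleton dimensions are sketched below; the weight values are made up for illustration. After reshaping, multiplying against an (N, 9, H, W) tensor scales each channel as a whole, which is the behaviour wanted for pos_weight:

```python
import torch

# Hypothetical per-channel weights with shape (9,), e.g. computed from the masks.
pos_weight = torch.linspace(1.0, 5.0, 9)

# Three equivalent ways to get shape (9, 1, 1).
w_index = pos_weight[:, None, None]
w_view = pos_weight.view(-1, 1, 1)
w_unsqueeze = pos_weight.unsqueeze(-1).unsqueeze(-1)

# Broadcasting against (N, C, H, W): every pixel of channel k gets pos_weight[k].
ones = torch.ones(2, 9, 4, 4)
scaled = w_view * ones

print(w_view.shape)  # torch.Size([9, 1, 1])
```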

    Calling the BCE-with-logits function (or its OOP equivalent, nn.BCEWithLogitsLoss):

    >>> F.binary_cross_entropy_with_logits(pred, label, pos_weight=pos_weight).mean()
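For the OOP form, the weight is passed once when constructing nn.BCEWithLogitsLoss, and the module then calls the same functional loss. A sketch comparing the two on a toy setup with the same sizes as above (the random seed is arbitrary):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
b, c, h, w = 5, 2, 3, 3
pred = torch.randn(b, c, h, w)                     # logits
label = torch.randint(0, 2, (b, c, h, w)).float()  # binary targets
pos_weight = torch.rand(c, 1, 1)                   # one weight per channel

# Functional form.
loss_fn = F.binary_cross_entropy_with_logits(pred, label, pos_weight=pos_weight)

# Module ("OOP") form: pos_weight is given at construction time.
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
loss_mod = criterion(pred, label)
```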
    

    This is equivalent, in plain code, to:

    >>> z = torch.sigmoid(pred)
    >>> bce = -(pos_weight*label*torch.log(z) + (1-label)*torch.log(1-z))
    >>> bce.mean()  # matches the default reduction='mean'
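The equivalence can be checked numerically. The sketch below uses float64 to keep rounding noise small, and also shows a variant built on F.logsigmoid: since log(1 - sigmoid(x)) = log(sigmoid(-x)), it avoids taking log of a sigmoid that has saturated to exactly 0 or 1 for large logits, which the direct transcription can suffer from:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
b, c, h, w = 5, 2, 3, 3
pred = torch.randn(b, c, h, w, dtype=torch.float64)
label = torch.randint(0, 2, (b, c, h, w)).double()
pos_weight = torch.rand(c, 1, 1, dtype=torch.float64)

builtin = F.binary_cross_entropy_with_logits(pred, label, pos_weight=pos_weight)

# Direct transcription of the formula, averaged like the default reduction.
z = torch.sigmoid(pred)
manual = -(pos_weight * label * torch.log(z) + (1 - label) * torch.log(1 - z)).mean()

# Numerically safer form: log(sigmoid(x)) and log(1 - sigmoid(x)) = log(sigmoid(-x)).
stable = -(pos_weight * label * F.logsigmoid(pred) + (1 - label) * F.logsigmoid(-pred)).mean()
```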
    

    Note that the built-in function would behave as desired (i.e. without an error) if the class dimension were the last dimension of your prediction and target tensors:

    >>> pos_weight = torch.rand(c)
    >>> F.binary_cross_entropy_with_logits(
    ...    pred.transpose(1,-1), 
    ...    label.transpose(1,-1), 
    ...    pos_weight=pos_weight)
    

    In other words, we are applying the function with the class dimension moved last (a channels-last layout, as in NHWC), which means the pos_weight of shape (C,) broadcasts properly. So the result above is effectively the same as:

    >>> F.binary_cross_entropy_with_logits(
    ...    pred, 
    ...    label, 
    ...    pos_weight=pos_weight[:,None,None])
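A quick numerical check that these calls agree. This sketch also uses permute(0, 2, 3, 1), which gives a true (N, H, W, C) layout; transpose(1, -1) gives (N, W, H, C) instead, but because the loss is element-wise and averaged, the order of the spatial axes does not change the result:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
b, c, h, w = 5, 2, 3, 3
pred = torch.randn(b, c, h, w)
label = torch.randint(0, 2, (b, c, h, w)).float()
pos_weight = torch.rand(c)  # 1-D, one weight per class

# Channels-last via permute: (N, H, W, C), so the (C,) weight broadcasts.
loss_permuted = F.binary_cross_entropy_with_logits(
    pred.permute(0, 2, 3, 1), label.permute(0, 2, 3, 1), pos_weight=pos_weight
)

# Channels-last via transpose(1, -1): (N, W, H, C), spatial axes swapped.
loss_transposed = F.binary_cross_entropy_with_logits(
    pred.transpose(1, -1), label.transpose(1, -1), pos_weight=pos_weight
)

# Channels-first: keep (N, C, H, W) and reshape the weight to (C, 1, 1).
loss_first = F.binary_cross_entropy_with_logits(
    pred, label, pos_weight=pos_weight[:, None, None]
)
```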
    

    You can read more about pos_weight in BCEWithLogitsLoss in another thread here.