tensorflow, pytorch, classification, multilabel-classification

Loss function for multi-label classification with many labels and sparse outcomes


Consider multi-label classification with an ANN, where the target labels are of the form

[0,0,0,0,0,1,0,1,0,0,0,1]
[0,0,0,1,0,0,0,0,0,0,0,0]
...

There are N labels, each either True (1) or False (0), represented as a vector of length N.

I ran into an issue with the loss function when training this network: since the vector length N is large compared to the number of True values (the multi-labels are "sparse"), the network can simply output constant all-zero vectors [0,0,0,0,0,0,0,0,0,0,0,0] and still achieve over 90% accuracy, because most of the labels are correctly predicted as 0.
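
A quick sketch of the problem (using the first label vector above): a constant all-zeros prediction already scores 9/12 = 75% element-wise accuracy here, and the sparser the labels, the closer this trivial baseline gets to 100%.

import torch

target = torch.tensor([0,0,0,0,0,1,0,1,0,0,0,1], dtype=torch.float32)
pred = torch.zeros_like(target)         # the degenerate all-zeros output
print((pred == target).float().mean())  # tensor(0.7500) -- 9 of 12 labels "correct"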

I tried Binary Cross-Entropy (BCE) and Categorical Cross-Entropy (CCE) in PyTorch and TensorFlow, but did not see any improvement.

I thought about writing something myself to increase the weight of the True (1) values, but I suspect that would just flip the results and make everything [1,1,1,1,1,1,1,1,1,1,1,1]?

What's the appropriate loss function for multi-label classification with many labels and sparse outcomes? (Examples in PyTorch and TensorFlow.)

A related post from almost 10 years ago: Multilabel image classification with sparse labels in TensorFlow?


Solution

  • I thought about writing something myself to increase the weight of the True (1) values, but I suspect that would just flip the results and make everything [1,1,1,1,1,1,1,1,1,1,1,1]?

    No, it wouldn't. When the target vector is really sparse, the 0 values are the largest contributors to the loss, so the network is biased toward outputting 0 simply because it is far more common. Up-weighting the positive entries counteracts that imbalance rather than flipping it.
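
    As a quick numeric sketch (arbitrary logits, purely for intuition): with one positive out of twelve labels, the eleven zero-terms dominate the summed BCE, so the cheapest way for the network to shrink the total loss is to push every output toward 0.

    import torch
    import torch.nn.functional as F

    target = torch.tensor([0.0] * 11 + [1.0])
    logits = torch.zeros(12)  # an "undecided" network: p = 0.5 for every label
    per_element = F.binary_cross_entropy_with_logits(logits, target, reduction="none")
    print(per_element[target == 0].sum())  # loss mass from the 11 zeros: ~7.62
    print(per_element[target == 1].sum())  # loss mass from the single one: ~0.69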

    What's the appropriate loss function for multi-label classification with many labels and sparse outcomes?

    For PyTorch: BCEWithLogitsLoss (if your network outputs logits, i.e. the output of the last layer with no sigmoid applied; otherwise BCELoss).
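
    It is worth noting that the built-in BCEWithLogitsLoss already supports this re-weighting natively through its pos_weight argument (a tensor with one multiplier per label). A minimal sketch, assuming 12 labels and an arbitrary example multiplier of 7.0:

    import torch

    num_labels = 12
    # pos_weight multiplies the loss terms of the positive (1) targets, label by label;
    # 7.0 is an arbitrary illustrative value, not a recommendation
    criterion = torch.nn.BCEWithLogitsLoss(
        pos_weight=torch.full((num_labels,), 7.0)
    )

    logits = torch.randn(4, num_labels)                     # raw outputs, no sigmoid
    targets = torch.randint(0, 2, (4, num_labels)).float()
    loss = criterion(logits, targets)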

    A hand-rolled implementation of the same idea could be:

    import typing
    import torch
    
    
    class WeightBCEWithLogitsLoss(torch.nn.Module):
        def __init__(
            self,
            weight: float,
            reduction: typing.Callable[[torch.Tensor], torch.Tensor] | None = None,
        ):
            super().__init__()
            # Multiplier applied to the loss terms of the positive (1) targets
            self.weight = weight
            if reduction is None:
                reduction = torch.mean
            self.reduction = reduction

        def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
            # Per-element BCE (no reduction yet), so positive entries can be re-weighted
            loss_matrix = torch.nn.functional.binary_cross_entropy_with_logits(
                input,
                target,
                reduction="none",
            )
            # Cast to bool for masking; indexing with the raw float target would fail
            loss_matrix[target.bool()] *= self.weight
            return self.reduction(loss_matrix)
    

    and usage:

    # 7x more weight on the positive samples during training
    criterion = WeightBCEWithLogitsLoss(weight=7.)
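
    A fuller usage sketch (the batch size, label count, and weight value are illustrative assumptions):

    import torch

    criterion = WeightBCEWithLogitsLoss(weight=7.0)

    logits = torch.randn(8, 12, requires_grad=True)  # raw scores, no sigmoid
    targets = torch.randint(0, 2, (8, 12)).float()   # sparse 0/1 labels
    loss = criterion(logits, targets)
    loss.backward()                                  # then optimizer.step() as usual

    For TensorFlow, the closest built-in I am aware of is tf.nn.weighted_cross_entropy_with_logits, which applies the same kind of positive-class multiplier:

    import tensorflow as tf

    logits = tf.random.normal((8, 12))
    targets = tf.cast(tf.random.uniform((8, 12), 0, 2, dtype=tf.int32), tf.float32)
    # pos_weight > 1 up-weights the positive (1) labels, mirroring the PyTorch class above
    loss = tf.reduce_mean(
        tf.nn.weighted_cross_entropy_with_logits(
            labels=targets, logits=logits, pos_weight=7.0
        )
    )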