Consider multi-label classification with an ANN, where the target labels are of the form
[0,0,0,0,0,1,0,1,0,0,0,1]
[0,0,0,1,0,0,0,0,0,0,0,0]
...
There are N labels, each either True (1) or False (0), represented as a vector of length N.
I ran into a problem with the loss function when training this network: because the length N of the vector is large compared to the number of True values (the multi-labels are "sparse"), the network can just constantly output all-zero vectors [0,0,0,0,0,0,0,0,0,0,0,0] and still achieve over 90% accuracy, since most of the labels are correctly predicted as 0.
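To make that failure mode concrete, here is a small sketch (using the two example target vectors above) that computes the per-label accuracy of the degenerate all-zeros prediction; with longer and sparser label vectors the same number climbs above 90%:

import torch

# The two example targets from above: 12 labels, only a few set to 1
targets = torch.tensor([
    [0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1],
    [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
], dtype=torch.float32)

all_zeros = torch.zeros_like(targets)             # the degenerate "always predict 0" output
accuracy = (all_zeros == targets).float().mean()
print(accuracy)                                   # 20/24, roughly 0.83 here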
I tried Binary Cross-Entropy (BCE) and Categorical Cross-Entropy (CCE) in pytorch and tensorflow, but did not get any improvement.
I thought about writing something myself to increase the weight on the True (1) values, but I suspect that would just flip the results and make everything [1,1,1,1,1,1,1,1,1,1,1,1]?
What's the appropriate loss function for Multi-Label Classification of multiple labels and sparse outcomes? (Example in pytorch and tensorflow)
A related post from almost 10 years ago: Multilabel image classification with sparse labels in TensorFlow?
> I thought about writing something myself to increase the weight on the True (1) values, but I suspect that would just flip the results and make everything [1,1,1,1,1,1,1,1,1,1,1,1]?
No, it wouldn't. The largest contribution to the loss comes from the 0 values when the vector is really sparse, so the neural network will be biased toward outputting 0, as it is far more common.
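A common way to pick such a weight (a standard heuristic, not something stated in the question) is the ratio of negative to positive labels in the training targets, so that both groups contribute roughly equally to the loss:

import torch

# The two example target vectors from the question, stacked into a (samples, labels) tensor
targets = torch.tensor([
    [0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1],
    [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
], dtype=torch.float32)

num_pos = targets.sum()
num_neg = targets.numel() - num_pos
weight = num_neg / num_pos   # 20 / 4 = 5 here; a reasonable starting point to tune from
print(weight)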
> What's the appropriate loss function for Multi-Label Classification of multiple labels and sparse outcomes?
For pytorch, use BCEWithLogitsLoss (in case your network outputs logits, i.e. the output of the last layer without sigmoid applied; otherwise BCELoss).
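Note that BCEWithLogitsLoss already has a built-in pos_weight argument that up-weights the positive entries per class, so the sparse-label imbalance can often be handled without a custom loss. A minimal sketch, assuming 12 labels and a batch of raw logits (the shapes and the weight of 7 are just illustrative):

import torch

num_labels = 12
# every label's positive (1) entries count ~7x more than its negative (0) entries
pos_weight = torch.full((num_labels,), 7.0)
criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)

logits = torch.randn(4, num_labels)                      # last-layer outputs, no sigmoid
targets = torch.randint(0, 2, (4, num_labels)).float()   # sparse 0/1 multi-labels
loss = criterion(logits, targets)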
An example implementation could be:
import typing

import torch


class WeightBCEWithLogitsLoss(torch.nn.Module):
    def __init__(
        self,
        weight: float,
        reduction: typing.Callable[[torch.Tensor], torch.Tensor] | None = None,
    ):
        super().__init__()
        self.weight = weight
        if reduction is None:
            reduction = torch.mean
        self.reduction = reduction

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        # Element-wise loss, kept unreduced so individual entries can be re-weighted
        loss_matrix = torch.nn.functional.binary_cross_entropy_with_logits(
            input,
            target,
            reduction="none",
        )
        # Boost the loss at the positive (True) labels; cast `target` to bool because
        # a float tensor cannot be used as a boolean index
        loss_matrix[target.bool()] *= self.weight
        return self.reduction(loss_matrix)
and usage:

# 7x more weight on the positive (True) samples during training
criterion = WeightBCEWithLogitsLoss(weight=7.)
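For tensorflow, a comparable sketch can lean on tf.nn.weighted_cross_entropy_with_logits, which takes the positive-class weight directly; the class name below and the weight of 7 are just illustrative:

import tensorflow as tf

class WeightedBCEWithLogits(tf.keras.losses.Loss):
    def __init__(self, pos_weight: float, **kwargs):
        super().__init__(**kwargs)
        self.pos_weight = pos_weight

    def call(self, y_true, y_pred):
        # y_pred are logits; positive (1) labels are weighted `pos_weight` times more
        per_element = tf.nn.weighted_cross_entropy_with_logits(
            labels=tf.cast(y_true, y_pred.dtype),
            logits=y_pred,
            pos_weight=self.pos_weight,
        )
        # mean over the label dimension; Keras applies its own reduction over the batch
        return tf.reduce_mean(per_element, axis=-1)

criterion = WeightedBCEWithLogits(pos_weight=7.0)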