pytorch classification loss-function cross-entropy

Multi-target loss recommendations


I'm working on a classification problem with 5 classes. My ground-truth vector has shape (3,) rather than holding a single class label. Its values are the acceptable classes. The predicted vector has shape (1, 5) and holds the softmax scores for all the classes.

For example:

predicted_vector = tensor([0.0669, 0.1336, 0.3400, 0.3392, 0.1203])
ground_truth = tensor([3,2,5])

In the example above, a plain argmax would declare class 3 the predicted class (score 0.34), but I want the model to be rewarded whenever the argmax class is any of 3, 2, or 5.

Which loss function is recommended for such a use case?


Solution

  • I'm assuming that a given sample is in exactly one class (say, class 3), but that for training purposes predicting class 2 or 5 is still acceptable, so the model shouldn't be penalised as heavily for those predictions.

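    This is a typical single-label, multi-class problem, but with probabilistic (“soft”) labels. Use CrossEntropyLoss, and pass it the model's raw logits rather than softmax() outputs, because CrossEntropyLoss applies log-softmax internally. Targets given as class probabilities are supported in PyTorch 1.10 and later.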

    In this example, the (soft) target might be a probability of 0.7 for class 3, a probability of 0.2 for class 2, and a probability of 0.1 for class 5 (and zero for everything else).
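
The soft-target setup described above could be sketched roughly as follows. This is a minimal illustration, not a definitive implementation: the logit values are made up, and it assumes the question's class labels (3, 2, 5) are 1-based, so they map to 0-based indices 2, 1, and 4. The 0.7 / 0.2 / 0.1 weights are the ones suggested in the answer.

```python
import torch
import torch.nn as nn

# Hypothetical raw model outputs (logits) for one sample over 5 classes.
# Illustrative values only; in practice these come from the model's last
# linear layer, with no softmax() applied.
logits = torch.tensor([[0.1, 0.8, 1.7, 1.7, 0.7]])

# Soft target as a probability distribution over the 5 classes.
# Assumption: the question's labels are 1-based, so class 3 -> index 2,
# class 2 -> index 1, class 5 -> index 4.
soft_target = torch.tensor([[0.0, 0.2, 0.7, 0.0, 0.1]])

# CrossEntropyLoss applies log-softmax to the logits internally.
# Floating-point targets of the same shape as the logits are interpreted
# as class probabilities (PyTorch 1.10 or later).
criterion = nn.CrossEntropyLoss()
loss = criterion(logits, soft_target)

# The same quantity written out by hand: -sum_c p_c * log q_c,
# averaged over the batch dimension.
log_q = torch.log_softmax(logits, dim=1)
manual = -(soft_target * log_q).sum(dim=1).mean()

print(loss.item(), manual.item())
```

The two printed values should agree up to floating-point error. How the probability mass is split among the acceptable classes is a modelling choice. If all three classes are meant to be equally acceptable, a uniform target of 1/3 each is one option.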