deep-learning probability conv-neural-network loss-function cross-entropy

Is it possible to implement a loss function that prioritizes the correct answer being in the top k probabilities?


I am working on a multi-class image recognition problem. The goal is for the correct answer to be among the top 3 output probabilities. So I was thinking that maybe there is a clever cost function that prioritizes the correct answer being in the top K and doesn't penalize much within those top K.


Solution

  • This can be achieved with a class-weighted cross-entropy loss, which essentially assigns a weight to the errors associated with each class. This loss is used in research; see, e.g., the paper "Multi-task learning and Weighted Cross-entropy for DNN-based Keyword" by S. Panchapagesan et al. Before computing the cross-entropy, you can check whether the predicted distribution satisfies your condition (e.g., the ground-truth class is among the top-k predicted classes) and, if it does, assign a zero (or near-zero) weight accordingly.

    There are open questions, though: when the correct class is in the top-k, should you still penalize the other k-1 classes in the top-k? What if, for example, the prediction is (0.9, 0.05, 0.01, ...), the third class is correct, and it is in the top-3: is this prediction good enough or not? Should you care exactly which k-1 incorrect classes they are?

    All these questions arise because this kind of loss doesn't have a probabilistic interpretation, unlike standard cross-entropy. That's why I wouldn't recommend using it in practice; I would reformulate the goal instead.

    For example, if the original problem is that for some inputs several classes are equally good, the best way to deal with it is to use soft labels, e.g. (0.33, 0.33, 0.33, 0, 0, 0, ...) instead of one-hot labels (note that this fully agrees with the probabilistic interpretation). This forces the network to learn features associated with all three good classes and generally leads to the same goal, but with better control over the target classes.
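The top-k check followed by weighted cross-entropy, as described in the answer, can be sketched in plain Python. This is a minimal illustration under a few assumptions, not a reference implementation: each row is assumed to already be a probability vector (e.g. a softmax output), the weight is attached per example according to the top-k check, and the names `in_top_k`, `top_k_gated_cross_entropy` and `hit_weight` are hypothetical. In an autograd framework the weight would be a constant derived from the ranking, so gradients flow only through `-log p[target]`.

```python
import math
from typing import List, Sequence


def in_top_k(probs: Sequence[float], target: int, k: int) -> bool:
    """Return True if `target` is among the k highest-probability classes.

    Ties are broken in favour of the lower class index (Python's sort is
    stable, even with reverse=True); this is an arbitrary choice.
    """
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    return target in ranked[:k]


def top_k_gated_cross_entropy(
    batch_probs: Sequence[Sequence[float]],
    targets: Sequence[int],
    k: int = 3,
    hit_weight: float = 0.0,  # hypothetical knob: 0 or "near zero"
    eps: float = 1e-12,
) -> float:
    """Mean weighted cross-entropy over a batch.

    An example whose target is already in the top-k gets weight
    `hit_weight`; every other example gets weight 1.
    """
    losses: List[float] = []
    for probs, target in zip(batch_probs, targets):
        weight = hit_weight if in_top_k(probs, target, k) else 1.0
        # Clamp to avoid log(0) for a zero-probability target.
        losses.append(-weight * math.log(max(probs[target], eps)))
    return sum(losses) / len(losses)


probs = [0.9, 0.05, 0.02, 0.01, 0.01, 0.01]
print(top_k_gated_cross_entropy([probs], [2]))  # target ranked 3rd: in top-3, loss 0.0
print(top_k_gated_cross_entropy([probs], [3]))  # target ranked 4th: full loss -log(0.01), about 4.605
```

With `hit_weight=0.0`, an example stops contributing any gradient as soon as its target enters the top-k, which is exactly why the open questions above matter: nothing pushes the correct class further up, and the other k-1 classes are never penalized.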
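The soft-label alternative from the answer amounts to replacing the one-hot target with a distribution q and minimizing the cross-entropy H(q, p) = -sum_i q_i log p_i, which is smallest when p equals q. The plain-Python sketch below is illustrative only; `soft_label` and `soft_cross_entropy` are hypothetical helper names. In PyTorch 1.10 and later, `torch.nn.functional.cross_entropy` also accepts class-probability targets, so in practice you would likely not hand-roll this.

```python
import math
from typing import List, Sequence


def soft_label(num_classes: int, good_classes: Sequence[int]) -> List[float]:
    """Target distribution that spreads the mass uniformly over the classes
    considered equally good, e.g. (1/3, 1/3, 1/3, 0, ...)."""
    target = [0.0] * num_classes
    for c in good_classes:
        target[c] = 1.0 / len(good_classes)
    return target


def soft_cross_entropy(
    probs: Sequence[float], target: Sequence[float], eps: float = 1e-12
) -> float:
    """Cross-entropy H(q, p) = -sum_i q_i * log(p_i) for a soft target q.

    Terms with q_i == 0 are skipped (they contribute nothing); p_i is
    clamped to `eps` to avoid log(0).
    """
    return -sum(q * math.log(max(p, eps)) for q, p in zip(target, probs) if q > 0)


q = soft_label(6, [0, 1, 2])
spread = [1 / 3, 1 / 3, 1 / 3, 0.0, 0.0, 0.0]   # matches q
collapsed = [1.0, 0.0, 0.0, 0.0, 0.0, 0.0]      # all mass on one good class
print(soft_cross_entropy(spread, q))     # about log(3), the minimum for this q
print(soft_cross_entropy(collapsed, q))  # much larger
```

Unlike the top-k gating, this loss keeps a probabilistic meaning: the network is rewarded for spreading probability over all good classes rather than for merely ranking one of them high enough.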