Tags: tensorflow, neural-network, gradient-descent, reinforcement-learning

When doing supervised classification with a NN, why do we train on cross-entropy and not on classification error?


The standard supervised classification setup: we have a bunch of samples, each with its correct label out of N labels. We build a NN with N outputs, transform those into probabilities with softmax, and the loss is the mean cross-entropy between each NN output and the corresponding true label, represented as a 1-hot vector with 1 at the true label and 0 elsewhere. We then optimize this loss by following its gradient. Classification error is used only to measure model quality.
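A minimal plain-Python sketch of this setup, assuming a batch of raw NN outputs (logits) and integer labels; all helper names are hypothetical and the numbers are made up. It shows the training loss (mean cross-entropy against 1-hot targets) next to the evaluation metric (classification error).

```python
import math

def softmax(logits):
    """Turn the N raw NN outputs (logits) into probabilities."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def one_hot(label, n):
    """1-hot target: 1 at the true label, 0 elsewhere."""
    return [1.0 if i == label else 0.0 for i in range(n)]

def mean_cross_entropy(batch_logits, labels):
    """Training loss: mean over samples of -sum_i q_i log p_i."""
    n = len(batch_logits[0])
    losses = []
    for logits, label in zip(batch_logits, labels):
        p = softmax(logits)
        q = one_hot(label, n)
        # terms with q_i == 0 contribute nothing, so skip them (avoids log(0))
        losses.append(-sum(qi * math.log(pi) for qi, pi in zip(q, p) if qi > 0))
    return sum(losses) / len(losses)

def classification_error(batch_logits, labels):
    """Evaluation metric only: fraction of samples whose argmax is not the label."""
    wrong = 0
    for logits, label in zip(batch_logits, labels):
        pred = max(range(len(logits)), key=lambda i: logits[i])
        wrong += int(pred != label)
    return wrong / len(labels)

batch_logits = [[2.0, 0.5, -1.0], [0.1, 0.3, 0.2]]
labels = [0, 2]
loss = mean_cross_entropy(batch_logits, labels)     # what gradient descent minimizes
error = classification_error(batch_logits, labels)  # 0.5: the second sample is misclassified
```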

HOWEVER, I know that when doing policy gradient we can use the likelihood ratio trick, and then we no longer need cross-entropy! Our loss simply uses tf.gather to pick the NN output corresponding to the correct label, e.g. in this solution of OpenAI Gym CartPole.
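A sketch of the kind of loss described here, assuming the common REINFORCE formulation in which the gathered quantity is the log-probability of the action actually taken, weighted by the return that followed. This is an assumption: it is not taken from the linked CartPole solution, and reinforce_loss and log_softmax are hypothetical helpers written in plain Python.

```python
import math

def log_softmax(logits):
    """log p_i = z_i - logsumexp(z), computed stably."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - lse for z in logits]

def reinforce_loss(batch_logits, actions, returns):
    """Hypothetical REINFORCE-style surrogate loss (not from the linked solution).

    For each step, gather the log-probability of the action actually taken
    and weight it by the return that followed. The gradient of this loss is
    the likelihood-ratio (score-function) estimate of the policy gradient.
    """
    total = 0.0
    for logits, a, g in zip(batch_logits, actions, returns):
        log_p = log_softmax(logits)
        total += -log_p[a] * g  # the tf.gather step: keep only entry a
    return total / len(actions)

logits = [[0.0, 0.0], [1.0, -1.0]]
actions = [1, 0]
# With every return equal to 1 this is just the mean cross-entropy,
# using the taken actions as labels.
loss_plain = reinforce_loss(logits, actions, [1.0, 1.0])
# Larger returns scale the loss (and its gradient) for those actions.
loss_weighted = reinforce_loss(logits, actions, [2.0, 2.0])
```

With all returns equal to 1, the gathered loss coincides with the mean cross-entropy of the supervised setup, with the taken actions playing the role of labels.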

WHY can't we use the same trick when doing supervised learning? I thought the reason we use cross-entropy is that it is differentiable, but apparently tf.gather is differentiable as well.

I mean: IF we measure ourselves on classification error, and we CAN optimize classification error directly because it's differentiable, isn't it BETTER to optimize for classification error instead of this weird cross-entropy proxy?
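The premise that classification error is differentiable can be checked numerically. The sketch below (plain Python; all helper names are hypothetical, the logits are made up) compares finite-difference gradients of the 0-1 classification error and of cross-entropy with respect to the logits. The 0-1 error only changes when the argmax flips, so it is piecewise constant and its gradient is zero almost everywhere; picking out a predicted probability with a gather is smooth, but thresholding it with an argmax removes the gradient signal. Cross-entropy's gradient, softmax(logits) minus the one-hot target, is nonzero whenever the prediction is not already certain.

```python
import math

def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    s = sum(exps)
    return [e / s for e in exps]

def zero_one_error(logits, label):
    """Classification error for one sample: 1 if argmax is wrong, else 0."""
    pred = max(range(len(logits)), key=lambda i: logits[i])
    return 0.0 if pred == label else 1.0

def cross_entropy(logits, label):
    """Cross-entropy against a one-hot target: -log p_label."""
    return -math.log(softmax(logits)[label])

def numerical_grad(f, logits, label, eps=1e-4):
    """Central finite-difference gradient of f(logits, label) w.r.t. each logit."""
    grad = []
    for i in range(len(logits)):
        up = list(logits)
        up[i] += eps
        down = list(logits)
        down[i] -= eps
        grad.append((f(up, label) - f(down, label)) / (2 * eps))
    return grad

logits = [1.0, 2.0, 0.5]  # the model currently predicts class 1
label = 0                 # but the true class is 0

err_grad = numerical_grad(zero_one_error, logits, label)  # [0.0, 0.0, 0.0]: no signal
ce_grad = numerical_grad(cross_entropy, logits, label)    # close to softmax(logits) - one_hot(label)
```

The negative first component of ce_grad means gradient descent raises the logit of the true class, which is exactly the signal the 0-1 error cannot provide.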


Solution

  • Policy gradient does use cross-entropy (or KL divergence, as Ishant pointed out). For supervised learning, tf.gather is really just an implementation trick, nothing else. For RL, on the other hand, it is a must, because you do not know "what would happen" if you had executed a different action. Consequently you end up with a high-variance estimator of your gradients, something you would like to avoid at all costs, if possible.

    Going back to supervised learning, though:

    CE(p||q) = - SUM_i q_i log p_i
    

    Let's assume that q is one-hot encoded, with a 1 at the k-th position. All other terms of the sum vanish, so

    CE(p||q) = - q_k log p_k = - log p_k
    

    So if you want, you can implement this with tf.gather; it simply does not matter. Cross-entropy is just more general, because it also handles soft targets (any probability distribution q, not only one-hot vectors). In particular, TF provides a sparse cross-entropy (tf.nn.sparse_softmax_cross_entropy_with_logits) that does exactly what you describe: it exploits the one-hot encoding, that's it. Mathematically there is no difference, computationally there is only a small difference, and there are functions that do exactly what you want.
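A minimal plain-Python sketch of the derivation above, with made-up probabilities and hypothetical helper names: for a one-hot target the full sum and the single gathered term give the same value, while a soft target needs every term. It works on probabilities for readability; TensorFlow's ..._with_logits cross-entropy ops take logits instead.

```python
import math

def cross_entropy(p, q):
    """CE(p||q) = -sum_i q_i log p_i for predicted probabilities p and target q."""
    # terms with q_i == 0 contribute nothing, so skip them (avoids log(0))
    return -sum(qi * math.log(pi) for pi, qi in zip(p, q) if qi > 0)

def gathered_cross_entropy(p, k):
    """The same quantity when q is one-hot at index k: just pick out -log p_k."""
    return -math.log(p[k])

p = [0.7, 0.2, 0.1]          # made-up predicted probabilities
k = 1                        # true label
q_one_hot = [0.0, 1.0, 0.0]
q_soft = [0.1, 0.8, 0.1]     # a soft target, e.g. produced by label smoothing

full = cross_entropy(p, q_one_hot)
gathered = gathered_cross_entropy(p, k)  # equal to full
soft = cross_entropy(p, q_soft)          # needs the whole sum; a single gather cannot express it
```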