The standard supervised classification setup: we have a bunch of samples, each with the correct label out of N possible labels. We build a NN with N outputs, transform those to probabilities with softmax, and the loss is the mean cross-entropy between each NN output and the corresponding true label, represented as a 1-hot vector with a 1 in the true label and 0 elsewhere. We then optimize this loss by following its gradient. The classification error is used just to measure our model quality.
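A minimal sketch of this setup (assuming TF 2.x; the shapes, 32 samples and 10 classes, are made up for illustration):

```python
import tensorflow as tf

# Made-up shapes: a batch of 32 samples and N = 10 classes.
logits = tf.random.normal([32, 10])                          # NN outputs, one score per class
labels = tf.random.uniform([32], maxval=10, dtype=tf.int32)  # true class indices

# One-hot encode the true labels and take the mean cross-entropy over the batch.
one_hot = tf.one_hot(labels, depth=10)
loss = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(labels=one_hot, logits=logits))
```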
HOWEVER, I know that when doing policy gradient we can use the likelihood ratio trick, and we no longer need cross-entropy! Our loss simply uses tf.gather to pick the NN output corresponding to the correct label. E.g. this solution of the OpenAI Gym CartPole problem.
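For concreteness, here is a rough sketch of what that gather-style policy-gradient loss tends to look like (assuming TF 2.x; the tensors below are random stand-ins, not taken from that CartPole solution):

```python
import tensorflow as tf

# Random stand-ins: policy logits, the actions that were actually taken, and their returns.
logits = tf.random.normal([32, 2])                            # e.g. CartPole: 2 actions
actions = tf.random.uniform([32], maxval=2, dtype=tf.int32)
returns = tf.random.normal([32])                              # discounted returns (placeholder values)

# Gather the log-probability of the executed action for each sample...
log_probs = tf.nn.log_softmax(logits)
chosen = tf.gather(log_probs, actions, axis=1, batch_dims=1)

# ...and weight it by the return: the REINFORCE / likelihood-ratio objective.
pg_loss = -tf.reduce_mean(chosen * returns)
```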
WHY can't we use the same trick when doing supervised learning? I was thinking that the reason we use cross-entropy is that it is differentiable, but apparently tf.gather is differentiable as well.

I mean: IF we measure ourselves on classification error, and we CAN optimize for classification error as it's differentiable, isn't it BETTER to also optimize for classification error instead of this weird cross-entropy proxy?
Policy gradient is using cross-entropy (or KL divergence, as Ishant pointed out). For supervised learning, tf.gather is really just an implementation trick, nothing else. For RL, on the other hand, it is a must, because you do not know "what would happen" if you had executed another action. Consequently you end up with a high-variance estimator of your gradients, something that you would like to avoid at all costs, if possible.
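For reference, a sketch of the likelihood-ratio (score-function) identity behind this, written in a one-step bandit-style notation that is my own shorthand rather than the question's:

$$\nabla_\theta \, \mathbb{E}_{a \sim \pi_\theta}\big[R(a)\big] = \mathbb{E}_{a \sim \pi_\theta}\big[R(a)\, \nabla_\theta \log \pi_\theta(a)\big]$$

The right-hand side is estimated only from the actions you actually sampled (and their rewards), which is exactly where the high variance comes from.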
Going back to supervised learning, though:

$$CE(p \| q) = -\sum_i q_i \log p_i$$

Let's assume that $q$ is one-hot encoded, with a 1 at the $k$-th position; then

$$CE(p \| q) = -q_k \log p_k = -\log p_k$$
So if you want, you can implement this as tf.gather; it simply does not matter. Cross-entropy is simply more generic because it handles more complex targets. In particular, in TF you have sparse cross-entropy, which does exactly what you describe: it exploits the one-hot encoding, that's it. Mathematically there is no difference, there is a small difference computation-wise, and there are functions doing exactly what you want.
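To make the equivalence concrete, a minimal sketch (assuming TF 2.x; the logits and labels are random stand-ins) showing that the gather formulation and tf.nn.sparse_softmax_cross_entropy_with_logits compute the same loss:

```python
import tensorflow as tf

# Random stand-ins, just to check that the two formulations agree.
logits = tf.random.normal([32, 10])
labels = tf.random.uniform([32], maxval=10, dtype=tf.int32)

# Cross-entropy written the "gather" way: pick out -log p_k for the true class k.
log_probs = tf.nn.log_softmax(logits)
gather_loss = -tf.reduce_mean(tf.gather(log_probs, labels, axis=1, batch_dims=1))

# The built-in sparse cross-entropy, which exploits the one-hot encoding directly.
sparse_loss = tf.reduce_mean(
    tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits))

# Both give the same number, up to floating-point noise.
tf.debugging.assert_near(gather_loss, sparse_loss)
```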