python, machine-learning, neural-network, loss-function

Understanding Cross Entropy Loss


I see a lot of explanations about cross entropy loss (CEL) or binary cross entropy (BCE) loss in the context where the ground truth is, say, a 0 or 1, and then you get a function like:

from math import log

# Here yHat is the ground-truth label (0 or 1) and y is the predicted probability.
def CrossEntropy(yHat, y):
    if yHat == 1:
        return -log(y)
    else:
        return -log(1 - y)

However, I'm confused about how BCE works when your yHat is not a discrete 0 or 1. For example, if I want to look at reconstruction loss for an MNIST digit where my ground truths satisfy 0 < yHat < 1, and my predictions are also in the same range, how does this change my function?

EDIT:

Apologies, let me give some more context for my confusion. In the PyTorch tutorial on VAEs, BCE is used to calculate the reconstruction loss, where yHat (as far as I understand) is not discrete. See:

https://github.com/pytorch/examples/blob/master/vae/main.py

The implementation works...but I don't understand how that BCE loss is calculated in this case.
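
For reference, the relevant call looks roughly like this (a paraphrased sketch with assumed shapes and stand-in tensors, not the exact code from that file):

import torch
import torch.nn.functional as F

# recon_x: decoder output in (0, 1), e.g. after a sigmoid
# x:       a batch of input images with pixel intensities in [0, 1]
recon_x = torch.sigmoid(torch.randn(16, 784))   # stand-in for the decoder output
x = torch.rand(16, 1, 28, 28)                   # stand-in for a batch of MNIST images
BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum')

Both arguments are floats between 0 and 1, so neither the prediction nor the target is a discrete label here.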


Solution

  • Cross entropy measures the distance between any two probability distributions. In the setting you describe (the VAE), the MNIST pixel intensities are interpreted as probabilities of each pixel being "on". In that case your target probability distribution is simply not a Dirac distribution (exactly 0 or 1) but can take any value in [0, 1]. See the cross entropy definition on Wikipedia.

    With the above as a reference, let's say your model outputs a reconstruction value of 0.7 for a certain pixel. This essentially says that your model estimates p(pixel=1) = 0.7, and accordingly p(pixel=0) = 0.3.
    If the target pixel were just 0 or 1, your cross entropy for this pixel would be either -log(0.3) if the true pixel is 0, or -log(0.7) (a smaller value) if the true pixel is 1.
    The full formula would be -(0*log(0.3) + 1*log(0.7)) if the true pixel is 1, or -(1*log(0.3) + 0*log(0.7)) otherwise.

    Now let's say your target pixel is actually 0.6! This essentially says that the pixel has a probability of 0.6 of being on and 0.4 of being off.
    This simply changes the cross entropy computation to -(0.4*log(0.3) + 0.6*log(0.7)).

    Finally, you can simply average/sum these per-pixel cross-entropies over the image (see the sketch below).
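
A minimal sketch in plain Python (the pixel values are made up for illustration) that reproduces the numbers above and then averages over an image:

from math import log

def pixel_bce(p_on, target):
    # p_on:   model's estimate of p(pixel = 1), e.g. 0.7
    # target: probability that the pixel is actually on; 0 or 1 for hard
    #         labels, anything in between for soft targets such as 0.6
    return -(target * log(p_on) + (1 - target) * log(1 - p_on))

print(pixel_bce(0.7, 1))    # hard target 1: -log(0.7) ≈ 0.357
print(pixel_bce(0.7, 0))    # hard target 0: -log(0.3) ≈ 1.204
print(pixel_bce(0.7, 0.6))  # soft target:   -(0.6*log(0.7) + 0.4*log(0.3)) ≈ 0.696

# For a whole image, average (or sum) the per-pixel values.
preds   = [0.7, 0.2, 0.9]
targets = [0.6, 0.0, 1.0]
print(sum(pixel_bce(p, t) for p, t in zip(preds, targets)) / len(preds))

This per-pixel formula, -(t*log(p) + (1-t)*log(1-p)), is also what PyTorch's binary_cross_entropy computes element-wise before reducing, which is why the VAE example can pass non-discrete pixel intensities as targets.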