python, machine-learning, tensorflow, cross-entropy

Calculating Cross Entropy in TensorFlow


I am having a hard time calculating cross entropy in TensorFlow. In particular, I am using the function:

tf.nn.softmax_cross_entropy_with_logits()

Using seemingly simple code, I can only get it to return zero:

import tensorflow as tf
import numpy as np

sess = tf.InteractiveSession()

# One value per row: a holds the labels, b holds the logits.
a = tf.placeholder(tf.float32, shape=[None, 1])
b = tf.placeholder(tf.float32, shape=[None, 1])
sess.run(tf.global_variables_initializer())
c = tf.nn.softmax_cross_entropy_with_logits(
    logits=b, labels=a
).eval(feed_dict={b: np.array([[0.45]]), a: np.array([[0.2]])})
print(c)

returns

0

My understanding of cross entropy is as follows:

H(p,q) = -Σ p(x)*log(q(x))

where p(x) is the true probability of event x and q(x) is the predicted probability of event x.

Therefore, if any two values for p(x) and q(x) are used such that

0 < p(x) < 1 and 0 < q(x) < 1

there should be a nonzero cross entropy. I expect that I am using TensorFlow incorrectly. Thanks in advance for any help.
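
As a quick sanity check of that formula, here is a minimal NumPy sketch; the two-class distributions below are illustrative values chosen so that each row is a valid probability distribution:

import numpy as np

# True distribution p and predicted distribution q over two events.
p = np.array([0.2, 0.8])
q = np.array([0.45, 0.55])

# H(p, q) = -sum(p * log(q)); nonzero here, as the question expects.
print(-np.sum(p * np.log(q)))  # ~0.638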


Solution

  • Like they say, you can't spell "softmax_cross_entropy_with_logits" without "softmax". The op applies softmax across each row of logits, and a row with a single entry always normalizes to [1], whatever its value. So softmax of [0.45] is [1], and log(1) is 0, which makes the cross entropy -1*log(1) = 0. See the corrected sketch after the quoted documentation below.

    Measures the probability error in discrete classification tasks in which the classes are mutually exclusive (each entry is in exactly one class). For example, each CIFAR-10 image is labeled with one and only one label: an image can be a dog or a truck, but not both.

    NOTE: While the classes are mutually exclusive, their probabilities need not be. All that is required is that each row of labels is a valid probability distribution. If they are not, the computation of the gradient will be incorrect.

    If using exclusive labels (wherein one and only one class is true at a time), see sparse_softmax_cross_entropy_with_logits.

    WARNING: This op expects unscaled logits, since it performs a softmax on logits internally for efficiency. Do not call this op with the output of softmax, as it will produce incorrect results.

    logits and labels must have the same shape [batch_size, num_classes] and the same dtype (either float16, float32, or float64).
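
    One way to adapt the original snippet, then, is to give each row at least two classes so the internal softmax has something to normalize over; the values below are illustrative, using the same TF 1.x API as the question:

    import tensorflow as tf
    import numpy as np

    sess = tf.InteractiveSession()

    # Two classes per row: the internal softmax is no longer trivially [1].
    a = tf.placeholder(tf.float32, shape=[None, 2])  # labels: each row a valid distribution
    b = tf.placeholder(tf.float32, shape=[None, 2])  # logits: unscaled scores

    c = tf.nn.softmax_cross_entropy_with_logits(logits=b, labels=a).eval(
        feed_dict={b: np.array([[0.45, 0.55]]), a: np.array([[0.2, 0.8]])}
    )
    print(c)  # nonzero, ~[0.664]

    If instead each example belongs to exactly one class, encode the label as a class index and use tf.nn.sparse_softmax_cross_entropy_with_logits, as the documentation quoted above suggests.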