When I wrote this TensorFlow demo, I found the arg_max() function in the definition of correct_prediction:
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(
    logits=hypothesis, labels=Y))
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost)
correct_prediction = tf.equal(tf.arg_max(hypothesis, 1), tf.arg_max(Y, 1))
The TF API docs say arg_max() "Returns the index with the largest value across axes of a tensor."
Since we use softmax_cross_entropy_with_logits, the prediction (hypothesis) is expressed as probabilities, and arg_max() gives us the index of the largest probability in hypothesis. But Y holds the labels; it is not a probability distribution, so why do we need tf.arg_max(Y, 1)?
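To make the question concrete, here is a minimal NumPy sketch (the values are made up) of what arg_max(hypothesis, 1) does to the probability rows:

```python
import numpy as np

# Hypothetical softmax outputs for a batch of 3 examples, 4 classes each;
# each row sums to 1, so each row is a probability distribution.
hypothesis = np.array([
    [0.1, 0.7, 0.1, 0.1],
    [0.6, 0.2, 0.1, 0.1],
    [0.2, 0.2, 0.5, 0.1],
])

# argmax along axis 1 picks the most probable class index per row.
predicted = np.argmax(hypothesis, axis=1)
print(predicted)  # [1 0 2]
```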
arg_max(hypothesis, 1) returns an INDEX, while Y is a length-10 one-hot vector. tf.equal() can't do anything sensible with that pair, because the two values are not comparable.
So arg_max(Y, 1) also returns an INDEX. Now tf.equal() can do a sensible thing: 1 if the prediction matches the target, 0 otherwise.
Note that arg_max() is not a function about probabilities: it simply returns the index of the largest element, whatever the values are.
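A small NumPy sketch of that point, with made-up labels and predictions: argmax turns each one-hot row of Y back into a plain class index, so the two index vectors can be compared element-wise, exactly as tf.equal() does.

```python
import numpy as np

# One-hot labels for 3 examples with 4 classes (like Y in the question).
Y = np.array([
    [0, 1, 0, 0],  # class 1
    [1, 0, 0, 0],  # class 0
    [0, 0, 0, 1],  # class 3
])

# argmax recovers the class index from each one-hot row; no probabilities involved.
labels = np.argmax(Y, axis=1)

# Hypothetical predicted class indices, as arg_max(hypothesis, 1) would give.
predicted = np.array([1, 0, 2])

# Element-wise comparison of the two index vectors; its mean is the accuracy.
correct = predicted == labels
accuracy = correct.mean()
print(labels, correct, accuracy)
```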