
Why do we need tf.arg_max(Y, 1) with softmax in TensorFlow?


While writing a TensorFlow demo, I found this arg_max() function in the definition of correct_prediction:

cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(
            logits=hypothesis, labels=Y))
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost)

correct_prediction = tf.equal(tf.arg_max(hypothesis, 1), tf.arg_max(Y, 1))

The TF API documentation for arg_max says: "Returns the index with the largest value across axes of a tensor."

Since we use softmax_cross_entropy_with_logits, the predictions (hypothesis) are presented as probabilities, so the arg_max() function gives us the index of the largest predicted probability.

But Y holds the labels, not probabilities, so why do we need tf.arg_max(Y, 1)?


Solution

  • arg_max(hypothesis) returns an INDEX, but Y is a length-10 one-hot vector. tf.equal() can't do anything sensible here because the two things are not compatible.

    So arg_max(Y) converts Y to an INDEX as well. Now tf.equal() can do a sensible thing: 1 if the prediction matches the target, 0 otherwise.

    Note that arg_max() is not a function about probabilities: it simply returns the index of the largest element, and in a one-hot label vector the largest element (the 1) sits exactly at the index of the true class.
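
Here is a minimal sketch of that idea on a hypothetical 3-class batch of two examples. It assumes TensorFlow with eager execution and uses tf.argmax, the current spelling of the deprecated tf.arg_max; the numbers are made up for illustration:

import tensorflow as tf

# Hypothetical predictions and one-hot labels for a 3-class problem.
hypothesis = tf.constant([[0.1, 0.7, 0.2],   # largest probability at index 1
                          [0.8, 0.1, 0.1]])  # largest probability at index 0
Y = tf.constant([[0., 1., 0.],               # true class: index 1
                 [0., 0., 1.]])              # true class: index 2

pred_idx = tf.argmax(hypothesis, 1)  # -> [1, 0]
true_idx = tf.argmax(Y, 1)           # -> [1, 2]

# Both sides are now indices, so tf.equal() is a sensible comparison.
correct_prediction = tf.equal(pred_idx, true_idx)  # -> [True, False]
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))  # -> 0.5

Comparing tf.argmax(hypothesis, 1) directly against the one-hot Y would pit a shape-(2,) index tensor against a shape-(2, 3) matrix, which is why both tensors are reduced to indices first.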