machine-learning tensorflow neural-network

How to get class labels from TensorFlow prediction


I have a classification model in TF and can get a list of probabilities for the next class (preds). Now I want to select the highest element (argmax) and display its class label.

This may seem silly, but how can I get the class label that matches a position in the predictions tensor?

        feed_dict={g['x']: current_char}
        preds, state = sess.run([g['preds'],g['final_state']], feed_dict)
        prediction = tf.argmax(preds, 1)

preds gives me a vector of predictions for each class. Surely there must be an easy way to just output the most likely class (label)?

Some info about my model:

x = tf.placeholder(tf.int32, [None, num_steps], name='input_placeholder')
y = tf.placeholder(tf.int32, [None, 1], name='labels_placeholder')
batch_size = tf.shape(x)[0]
x_one_hot = tf.one_hot(x, num_classes)
rnn_inputs = [tf.squeeze(i, squeeze_dims=[1]) for i in
              tf.split(x_one_hot, num_steps, 1)] 

tmp = tf.stack(rnn_inputs)
print(tmp.get_shape())
tmp2 = tf.transpose(tmp, perm=[1, 0, 2])
print(tmp2.get_shape())
rnn_inputs = tmp2


with tf.variable_scope('softmax'):
    W = tf.get_variable('W', [state_size, num_classes])
    b = tf.get_variable('b', [num_classes], initializer=tf.constant_initializer(0.0))


rnn_outputs = rnn_outputs[:, num_steps - 1, :]
rnn_outputs = tf.reshape(rnn_outputs, [-1, state_size])
y_reshaped = tf.reshape(y, [-1])
logits = tf.matmul(rnn_outputs, W) + b
predictions = tf.nn.softmax(logits)

Solution

  • A prediction is an array with one entry per class (label). Each entry represents the model's "confidence" that the input belongs to that class. You can find the index of the label with the highest confidence by using:

    prediction = np.argmax(preds, 1)
    

    Once argmax has given you the index of the highest probability, use that index to look up the matching class name in your list of class labels.

    class_names[prediction[0]]  # prediction holds one index per row of preds
    

    Please refer to this link for more understanding.
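The lookup described above can be sketched end to end in plain NumPy. This is a minimal illustration, not the asker's actual setup: the `class_names` list, the sample `preds` values, and the helper `labels_from_preds` are all hypothetical stand-ins. It assumes `preds` has shape [batch, num_classes], as returned by `sess.run`, and that `class_names[i]` is the label for class index i.

```python
import numpy as np

# Hypothetical label list: position i holds the label for class index i.
class_names = ['a', 'b', 'c', 'd']

# Stand-in for the array returned by sess.run: one row per input,
# one probability per class.
preds = np.array([[0.1, 0.2, 0.6, 0.1],
                  [0.7, 0.1, 0.1, 0.1]])


def labels_from_preds(preds, class_names):
    """Return the most likely label for each row of preds."""
    # argmax along axis 1 gives one class index per row.
    indices = np.argmax(preds, axis=1)
    # Map each index to its label; a list comprehension works whether
    # class_names is a Python list or a NumPy array.
    return [class_names[i] for i in indices]


predicted_labels = labels_from_preds(preds, class_names)
print(predicted_labels)  # ['c', 'a']
```

Because `sess.run` already returns NumPy arrays, `np.argmax` can be applied to `preds` directly, with no extra graph op or second `sess.run` call. For a character model, `class_names` would typically be the same index-to-character mapping that was used to encode the training data.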