python tensorflow softmax

How to set the indices of the two or more largest values to 1 in the output of tf.nn.softmax


I want to set the positions of the largest values (two or more) in the output of tf.nn.softmax() to 1 and all other positions to 0. For example, given the output [0.1, 0.4, 0.2, 0.1, 0.8], I want to get [0, 1, 0, 0, 1], since those positions hold the largest values (in this case I chose just the top 2). Thank you in advance!


Solution

  • You can use tf.nn.top_k, which returns the k largest values of the input vector together with their indices, in the order (values, indices).

    import tensorflow as tf

    probs = tf.nn.softmax(logits)
    k = 2  # number of largest values to keep
    # tf.nn.top_k returns (values, indices), in that order
    values, indices = tf.nn.top_k(probs, k=k)
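
The snippet above stops at the indices, so one more step is needed to get the 0/1 vector the question asks for. A minimal sketch, assuming TensorFlow 2.x with eager execution: one-hot encode each of the k indices and sum the encodings. The helper name top_k_mask is hypothetical, and the input is the example vector from the question rather than a real softmax output.

```python
import tensorflow as tf


def top_k_mask(probs, k):
    """Return a 0/1 tensor marking the k largest entries along the last axis.

    Hypothetical helper for illustration; not part of TensorFlow.
    """
    # top_k returns (values, indices); only the indices are needed here.
    _, indices = tf.nn.top_k(probs, k=k)            # shape [..., k]
    depth = tf.shape(probs)[-1]                     # length of the last axis
    # One one-hot row per selected index.
    one_hot = tf.one_hot(indices, depth=depth, dtype=tf.int32)  # [..., k, depth]
    # The k indices are distinct, so summing the rows yields only 0s and 1s.
    return tf.reduce_sum(one_hot, axis=-2)          # shape [..., depth]


probs = tf.constant([0.1, 0.4, 0.2, 0.1, 0.8])     # example vector from the question
mask = top_k_mask(probs, k=2)
print(mask.numpy())  # [0 1 0 0 1]
```

Because the reduction runs over the second-to-last axis, the same function should also work on a batch of shape [batch, num_classes]. When values are tied, tf.nn.top_k picks the lower index first, so exactly k positions are set to 1.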