I want to set the indices of the largest values (2 or more) to 1 in the output of tf.nn.softmax().
Given tf.nn.softmax's output [0.1, 0.4, 0.2, 0.1, 0.8], I want to get something like [0,1,0,0,1], since those indices hold the largest values (in this case I chose just the top 2). Thank you in advance!
You can use tf.nn.top_k, which returns the k highest values of the input vector (along its last dimension) together with their positions.
import tensorflow as tf

probs = tf.nn.softmax(logits)  # logits: your model's output
k = 2  # keep the k = 2 highest values
values, indices = tf.nn.top_k(probs, k=k)  # returns (values, indices), in that order
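tf.nn.top_k only gives you the positions, not the 0/1 vector from the question. One way to get that vector is to one-hot encode the k indices and collapse them into a single mask. The sketch below assumes TensorFlow 2.x with eager execution; `top_k_mask` is a hypothetical helper name, and the input values are the ones from the question (they are not a real softmax output, just illustrative).

```python
import tensorflow as tf

def top_k_mask(probs, k):
    """Hypothetical helper: return an int32 tensor with 1 at the positions of
    the k largest entries along the last axis of `probs`, and 0 elsewhere."""
    # tf.nn.top_k returns (values, indices); only the indices are needed here.
    _, indices = tf.nn.top_k(probs, k=k)
    # One-hot encode each selected index: shape [..., k, n].
    one_hot = tf.one_hot(indices, depth=tf.shape(probs)[-1], dtype=tf.int32)
    # Collapse the k axis so every selected position becomes 1: shape [..., n].
    return tf.reduce_max(one_hot, axis=-2)

# Values from the question, used in place of an actual softmax output.
probs = tf.constant([0.1, 0.4, 0.2, 0.1, 0.8])
mask = top_k_mask(probs, k=2)
print(mask.numpy())  # [0 1 0 0 1]
```

Because the reduction runs over axis -2, the same helper also works on a batch of shape [batch, n], producing one mask per row.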