This code from the TensorFlow translate.py example confused me:
# This is a greedy decoder - outputs are just argmaxes of output_logits.
outputs = [int(np.argmax(logit, axis=1)) for logit in output_logits]
Why does taking the argmax work?
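output_logits is effectively [bucket_length, batch_size, target_vocab_size]: it is a list with one entry per decoder time step (bucket_length entries), and each entry has shape [batch_size, target_vocab_size], holding one row of scores over the target vocabulary for each sentence in the batch.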
For each decoder step (each logit), they take the index of the highest activation, which is the ID of the most likely word in the target vocabulary.
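A minimal sketch of that greedy step, assuming NumPy and hand-written logits in place of real model output (the three steps, the 4-word vocabulary and all values are hypothetical):

```python
import numpy as np

# Hypothetical output_logits: a list with one [batch_size, vocab_size] array
# per decoder step. Here batch_size = 1 and the vocabulary has 4 words.
output_logits = [
    np.array([[0.1, 2.0, -1.0, 0.3]]),  # step 0: index 1 is largest
    np.array([[1.5, 0.2, 0.1, 0.0]]),   # step 1: index 0 is largest
    np.array([[-0.5, 0.0, 0.4, 3.2]]),  # step 2: index 3 is largest
]

# Greedy decoding: pick the highest-scoring word ID at every step.
# .item() converts the one-element result to a Python int, like int() in translate.py.
outputs = [np.argmax(logit, axis=1).item() for logit in output_logits]
print(outputs)  # [1, 0, 3]
```

Each entry of outputs is one token ID, so the list is the decoded sentence (before any end-of-sentence truncation).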
To see how argmax behaves, take a look at the NumPy examples on this page: https://docs.scipy.org/doc/numpy/reference/generated/numpy.argmax.html
>>> import numpy as np
>>> a = np.array([[0, 1, 2],
...               [3, 4, 5]])
>>> np.argmax(a)
5
>>> np.argmax(a, axis=0)
array([1, 1, 1])
>>> np.argmax(a, axis=1)
array([2, 2])
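The same idea extends to three dimensions. As a sketch, assuming the list of logits were stacked into one hypothetical array of shape [bucket_length, batch_size, vocab_size], argmax over the last axis (axis=2) gives the word IDs for every step and sentence at once, and matches the per-step axis=1 calls:

```python
import numpy as np

# Hypothetical logits: 2 decoder steps, batch of 2 sentences, vocabulary of 3 words.
logits = np.array([
    [[0.1, 2.0, -1.0], [0.5, 0.2, 0.9]],  # step 0
    [[1.5, 0.2, 0.1], [0.0, 3.0, 0.1]],   # step 1
])
print(logits.shape)  # (2, 2, 3)

# argmax over the vocabulary axis leaves shape (steps, batch_size)
ids = np.argmax(logits, axis=2)
print(ids.tolist())  # [[1, 2], [0, 1]]

# Same result as calling argmax with axis=1 on each step's 2D slice
per_step = [np.argmax(step, axis=1) for step in logits]
print([p.tolist() for p in per_step])  # [[1, 2], [0, 1]]
```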
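So what the outputs line does is: for each logit matrix, np.argmax(logit, axis=1) looks along the vocabulary axis and returns, for every row (every sentence in the batch), the index of the largest score. That index is the greedily chosen word ID for that time step.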
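Now look at the shape of that result: it is [batch_size]. In translate.py's decode(), batch_size is set to 1 (it decodes one sentence at a time), so each result is a one-element array and int() turns it into a single token ID. That is why it all works out! With a larger batch, int() would fail on the multi-element array.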
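A small sketch of why batch_size = 1 matters, using hypothetical hand-written logits: with one sentence each argmax result has a single element, while with two sentences int() raises a TypeError. The batch-safe variant at the end (stacking and transposing) is an illustrative alternative, not what translate.py does.

```python
import numpy as np

# batch_size = 1: one row of (hypothetical) scores per decoder step
single = [np.array([[0.2, 1.0, 0.1]]),
          np.array([[2.0, 0.0, 0.5]])]
# batch_size = 2: two rows (sentences) per decoder step
batched = [np.array([[0.2, 1.0, 0.1], [0.0, 0.0, 3.0]]),
           np.array([[2.0, 0.0, 0.5], [0.0, 0.1, 0.9]])]

# With one sentence, each argmax result has shape (1,) and converts to one int.
single_ids = [np.argmax(logit, axis=1).item() for logit in single]
print(single_ids)  # [1, 0]

# With two sentences, each result has shape (2,); int() cannot convert it.
try:
    [int(np.argmax(logit, axis=1)) for logit in batched]
    int_failed = False
except TypeError:
    int_failed = True
print(int_failed)  # True

# Batch-safe variant: stack to shape (steps, batch_size), then transpose
# so each row holds one sentence's token IDs.
ids = np.stack([np.argmax(logit, axis=1) for logit in batched])
print(ids.T.tolist())  # [[1, 0], [2, 2]]
```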
Let me know if this helps you!