Tags: tensorflow, neural-network, nlp, deep-learning, machine-translation

How to decode the output of seq2seq?


This code from the TensorFlow translate.py example confuses me. The relevant lines are:

  # This is a greedy decoder - outputs are just argmaxes of output_logits.
  outputs = [int(np.argmax(logit, axis=1)) for logit in output_logits]

Why does the argmax work?

The shape of output_logits is [bucket_length, batch_size, embedding_size].


Solution

  • For each logit (that is, the activations for one output word), the code takes the index at which the activation is highest.

    For the argmax: take a look at the numpy examples on this page: https://docs.scipy.org/doc/numpy/reference/generated/numpy.argmax.html

    >>> a = np.array([[0, 1, 2],
    ...               [3, 4, 5]])
    >>> np.argmax(a)
    5
    >>> np.argmax(a, axis=0)
    array([1, 1, 1])
    >>> np.argmax(a, axis=1)
    array([2, 2])
    

    So what the outputs line does is:

    • For each word (there are bucket_length of them)
      • take the index of the largest activation along the last axis (embedding_size in your notation)

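    Look at the shape of what np.argmax returns at each step. Each logit has shape [batch_size, embedding_size], so np.argmax(logit, axis=1) returns an array of shape [batch_size]. Because batch_size is 1 here, that array holds a single index, which int() turns into a plain Python integer. That is why it all works out!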
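
To make the shapes concrete, here is a minimal sketch of the same greedy step on made-up logits. The values and the size of the last axis (5) are assumptions chosen only for illustration; they are not taken from translate.py. The sketch indexes the single element with [0] before calling int(), since recent NumPy releases warn when int() is applied directly to a one-element array.

```python
import numpy as np

# Made-up logits for illustration: 3 decoder steps, batch_size = 1,
# and a last axis of size 5. Each step has shape (batch_size, 5).
output_logits = [
    np.array([[0.1, 2.0, 0.3, 0.0, 0.0]]),
    np.array([[0.0, 0.1, 0.2, 3.0, 0.5]]),
    np.array([[1.5, 0.2, 0.1, 0.0, 0.3]]),
]

# argmax along axis=1 gives one index per batch entry, so shape (batch_size,).
step_shape = np.argmax(output_logits[0], axis=1).shape

# Greedy decoding: for every step, keep the index of the largest logit.
# With batch_size = 1 there is exactly one index per step.
outputs = [int(np.argmax(logit, axis=1)[0]) for logit in output_logits]

print(step_shape)  # (1,)
print(outputs)     # [1, 3, 0]
```

With a larger batch_size, np.argmax(logit, axis=1) would return several indices per step, and converting that array to a single int would no longer work.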

    Let me know if this helps you!