Tags: machine-learning, nlp, deep-learning, sequence-to-sequence, attention-model

What does the "source hidden state" refer to in the Attention Mechanism?


The attention weights are computed as:

[image: formula for the attention weights]

I want to know what h_s refers to.

In the TensorFlow code, the encoder RNN returns a tuple:

encoder_outputs, encoder_state = tf.nn.dynamic_rnn(...)

I assumed h_s should be encoder_state, but the nmt repository on GitHub does something different:

# attention_states: [batch_size, max_time, num_units]
attention_states = tf.transpose(encoder_outputs, [1, 0, 2])

# Create an attention mechanism
attention_mechanism = tf.contrib.seq2seq.LuongAttention(
    num_units, attention_states,
    memory_sequence_length=source_sequence_length)

Did I misunderstand the code, or does h_s actually mean encoder_outputs?


Solution

  • The formula is probably from this post, so I'll use an NN picture from the same post:

    [figure: encoder-decoder network with attention]

    Here, the h-bar(s) are all the blue hidden states from the encoder (the last layer), and h(t) is the current red hidden state from the decoder (also the last layer). In the picture t=0, and the dotted arrows show which blocks are wired to the attention weights. The score function is usually one of these:

    [image: score function formulas]
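
For reference, the attention weights, the context vector, and the two most common score functions are usually written as follows. This is a reconstruction in the standard Luong/Bahdanau notation, not a copy of the original images; the parameter names W, W_1, W_2 and v_a are assumptions.

```latex
\alpha_{ts} = \frac{\exp\big(\mathrm{score}(h_t, \bar{h}_s)\big)}
                   {\sum_{s'=1}^{S} \exp\big(\mathrm{score}(h_t, \bar{h}_{s'})\big)}
\qquad
c_t = \sum_{s} \alpha_{ts}\, \bar{h}_s

\mathrm{score}(h_t, \bar{h}_s) =
\begin{cases}
h_t^{\top} W \bar{h}_s & \text{(Luong, multiplicative)} \\
v_a^{\top} \tanh\big(W_1 h_t + W_2 \bar{h}_s\big) & \text{(Bahdanau, additive)}
\end{cases}
```

In both forms the sum runs over the source positions s, so there is one h-bar(s) per encoder time step rather than a single final state.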


    TensorFlow's attention mechanism matches this picture. In theory, a cell's output is in most cases its hidden state (one exception is the LSTM cell, whose output is the short-term part of the state, and even then the output is the better fit for an attention mechanism). Note also that encoder_state holds only the final state, while encoder_outputs holds the output at every time step, which is what attention needs. In practice, encoder_state also differs from the last time step of encoder_outputs when the input is padded and sequence_length is passed: past the end of a sequence the state is copied through from the previous step while the output is zero. You don't want to attend to the padded positions, so it makes sense for h-bar(s) to come from encoder_outputs, where those positions are zero.
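
The difference between the final state and the padded outputs can be illustrated with a toy example. This is a minimal sketch in plain Python, not TensorFlow: run_rnn is a hypothetical scalar RNN with made-up weights w_x and w_h, written only to mimic the zero-output, copy-through-state behaviour described above.

```python
import math


def run_rnn(inputs, seq_len, w_x=0.5, w_h=0.8):
    """Toy scalar RNN (hypothetical helper, arbitrary weights).

    For t < seq_len the state is updated and emitted as the output.
    For t >= seq_len the output is zero and the state is left unchanged.
    """
    state = 0.0
    outputs = []
    for t, x in enumerate(inputs):
        if t < seq_len:
            state = math.tanh(w_x * x + w_h * state)
            outputs.append(state)
        else:
            # Padded step: zero output, state copied through untouched.
            outputs.append(0.0)
    return outputs, state


# Three real inputs followed by two padding steps.
padded_inputs = [1.0, 2.0, 3.0, 0.0, 0.0]
encoder_outputs, encoder_state = run_rnn(padded_inputs, seq_len=3)

print(encoder_outputs)  # one value per time step, zeros on the padding
print(encoder_state)    # equals the output at the last real step
```

In this sketch the final state equals the output at the last real time step, not the last element of the outputs, and the outputs keep one entry per source position.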

    So encoder_outputs are exactly the arrows that go from the blue blocks upward. Later in the code, attention_mechanism is connected to each decoder_cell, so that its output goes through the context vector to the yellow block in the picture.

    decoder_cell = tf.contrib.seq2seq.AttentionWrapper(
        decoder_cell, attention_mechanism,
        attention_layer_size=num_units)
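
To make the role of encoder_outputs concrete, here is a rough sketch of a single attention step in plain Python. It is not what the TensorFlow wrapper executes: attention_step is a hypothetical helper, it assumes a plain dot-product score instead of LuongAttention's learned projection, and it stops at the context vector without the final attention layer.

```python
import math


def attention_step(h_t, encoder_outputs, source_length):
    """Return (weights, context) for one decoder step.

    h_t: decoder hidden state, a list of floats.
    encoder_outputs: one vector per source position (the h-bar(s)).
    source_length: number of real positions (must be >= 1); later ones are padding.
    """
    # Dot-product score per source position; padded positions are masked out.
    scores = [
        sum(a * b for a, b in zip(h_t, h_s)) if s < source_length else -math.inf
        for s, h_s in enumerate(encoder_outputs)
    ]
    # Numerically stable softmax over the source positions.
    top = max(scores)
    exps = [math.exp(score - top) for score in scores]
    total = sum(exps)
    weights = [e / total for e in exps]
    # Context vector: attention-weighted sum of the encoder outputs.
    context = [
        sum(w * h_s[i] for w, h_s in zip(weights, encoder_outputs))
        for i in range(len(h_t))
    ]
    return weights, context


# Two real source positions followed by one padded (all-zero) position.
encoder_outputs = [[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]]
h_t = [2.0, 0.0]
weights, context = attention_step(h_t, encoder_outputs, source_length=2)
print(weights, context)
```

The padded third position receives a weight of zero here, which is the kind of masking that memory_sequence_length provides in the real mechanism.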