Tags: python, tensorflow, machine-learning, keras, attention-model

Context vector shape using Bahdanau Attention


I am looking here at the Bahdanau attention class. I noticed that the final shape of the context vector is (batch_size, hidden_size). I am wondering how they got that shape given that attention_weights has shape (batch_size, 64, 1) and features has shape (batch_size, 64, embedding_dim). They multiplied the two (I believe it is a matrix product) and then summed up over the first axis. Where is the hidden size coming from in the context vector?


Solution

  • The context vector resulting from Bahdanau attention is a weighted average of all the hidden states of the encoder (the softmax-normalized attention weights sum to 1). The following image from Ref shows how this is calculated, and a code sketch follows the list. Essentially we do the following.

    1. Compute the attention weights, a (batch size, encoder time steps, 1) sized tensor: one scalar weight per encoder time step.
    2. Multiply each encoder hidden state (batch size, hidden size) element-wise by its attention weight. The weights broadcast over the hidden dimension, so this is an element-wise product, not a matrix product, resulting in (batch size, encoder time steps, hidden size).
    3. Sum over the time dimension; since the weights sum to 1, this is a weighted average, resulting in (batch size, hidden size). The last dimension of the context vector is therefore whatever the last dimension of the encoder states is. In the code you are asking about, that dimension is embedding_dim, which the tutorial's comment simply labels hidden_size: that is where the "hidden size" comes from.

    [Image: the context vector computed as an attention-weighted average of the encoder hidden states]
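
Here is a minimal, runnable sketch of that shape arithmetic, assuming TensorFlow 2.x. It mirrors the BahdanauAttention layer from the TF image-captioning tutorial the question refers to, but the layer sizes below (units, embedding_dim, hidden_size) are illustrative assumptions, not the tutorial's exact values:

```python
import tensorflow as tf

class BahdanauAttention(tf.keras.layers.Layer):
    """Illustrative re-implementation of Bahdanau (additive) attention."""

    def __init__(self, units):
        super().__init__()
        self.W1 = tf.keras.layers.Dense(units)
        self.W2 = tf.keras.layers.Dense(units)
        self.V = tf.keras.layers.Dense(1)

    def call(self, features, hidden):
        # features: (batch_size, 64, embedding_dim)
        # hidden:   (batch_size, hidden_size) -> add a time axis for broadcasting
        hidden_with_time_axis = tf.expand_dims(hidden, 1)

        # Unnormalized scores e: (batch_size, 64, 1)
        score = self.V(tf.nn.tanh(
            self.W1(features) + self.W2(hidden_with_time_axis)))

        # Softmax over the 64 time steps -> weights sum to 1 per example
        attention_weights = tf.nn.softmax(score, axis=1)

        # Element-wise multiply with broadcasting (NOT a matrix product):
        # (batch_size, 64, 1) * (batch_size, 64, embedding_dim)
        #   -> (batch_size, 64, embedding_dim)
        context_vector = attention_weights * features

        # Sum over the time axis; because the weights sum to 1 this is a
        # weighted average: (batch_size, embedding_dim)
        context_vector = tf.reduce_sum(context_vector, axis=1)
        return context_vector, attention_weights

# Shape check with dummy tensors (sizes are arbitrary examples)
batch_size, time_steps, embedding_dim, hidden_size = 16, 64, 256, 512
attn = BahdanauAttention(units=10)
features = tf.random.normal((batch_size, time_steps, embedding_dim))
hidden = tf.random.normal((batch_size, hidden_size))
context, weights = attn(features, hidden)
print(context.shape)  # (16, 256) -> (batch_size, embedding_dim)
print(weights.shape)  # (16, 64, 1)
```

Running this prints (16, 256) and (16, 64, 1): the context vector's last dimension is inherited from features, not from the decoder's hidden state, even when the code's comments call it hidden_size.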