Tags: pytorch, lstm, recurrent-neural-network, tensor, attention-model

How do I compute the attention-weighted average of encoder outputs in PyTorch?


I'm currently trying to implement an LSTM with attention in PyTorch, and as soon as it comes to dealing with batch sizes and multidimensional tensors I suddenly forget how linear algebra works. I have a tensor of attention scores of size [64, 19, 1], where 64 is the batch size and 19 is the max length of a source sentence. I also have a tensor of encoder outputs (hidden states) of shape [64, 19, 256], where 256 is the dimension of the hidden state. What's a decent way to compute the context vector, i.e. the average of the encoder hidden states weighted by the attention scores? My brain unfortunately can't reason about these operations once there are more than two dimensions and a batch dimension involved.


Solution

  • As simple as context = torch.sum(attention * encoder_hidden, dim=1). Broadcasting stretches the [64, 19, 1] scores across the 256 hidden units, and summing over dim=1 (the sequence dimension) leaves a [64, 256] context vector, one per batch element.
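
For completeness, here is a minimal runnable sketch of that one-liner; the random inputs and the torch.bmm alternative are illustrative assumptions, not part of the original answer:

    import torch

    batch_size, seq_len, hidden_dim = 64, 19, 256

    # Dummy inputs matching the shapes in the question (softmax makes the
    # scores sum to 1 along the sequence dimension, as attention weights would).
    attention = torch.softmax(torch.randn(batch_size, seq_len, 1), dim=1)  # [64, 19, 1]
    encoder_hidden = torch.randn(batch_size, seq_len, hidden_dim)          # [64, 19, 256]

    # Broadcasting expands the scores over the hidden dimension; summing over
    # dim=1 collapses the sequence into one context vector per batch element.
    context = torch.sum(attention * encoder_hidden, dim=1)                 # [64, 256]

    # Equivalent batched matrix multiplication:
    # [64, 1, 19] @ [64, 19, 256] -> [64, 1, 256], then drop the middle dim.
    context_bmm = torch.bmm(attention.transpose(1, 2), encoder_hidden).squeeze(1)

    assert torch.allclose(context, context_bmm, atol=1e-5)
    print(context.shape)  # torch.Size([64, 256])

The torch.bmm form is worth knowing because it generalizes to the multi-query case, where the scores tensor has shape [batch, queries, seq_len] instead of a single column.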