I'm currently trying to implement an LSTM with attention in PyTorch, and as soon as it comes to dealing with batch sizes and multidimensional tensors I suddenly forget how linear algebra works. I have a tensor of attention scores of size [64, 19, 1], where 64 is the batch size and 19 is the max length of a source sentence. I also have a tensor of encoder outputs (hidden states) of shape [64, 19, 256], where 256 is the dimension of the hidden state. What's a good way to compute the context vector, i.e. the average of the encoder hidden states weighted by the attention scores? Unfortunately, my brain can't reason about these things once there are more than two dimensions, especially with the batch dimension included.
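As simple as: context = torch.sum(attention * encoder_hidden, dim=1). Broadcasting expands attention from [64, 19, 1] to [64, 19, 256], so each hidden state is scaled by its score; summing over dim=1 (the source-length dimension) then gives a context tensor of shape [64, 256]. For this to be a true weighted average, the scores should already be normalized over that dimension, e.g. with torch.softmax(scores, dim=1).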
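A minimal sketch, assuming random tensors as stand-ins for the real attention scores and encoder outputs (the names raw_scores, context_sum and context_bmm are hypothetical). It checks the resulting shape and shows that the broadcast-and-sum one-liner matches the equivalent batched matrix multiply via torch.bmm:

```python
import torch

torch.manual_seed(0)

batch_size, src_len, hidden_dim = 64, 19, 256

# Hypothetical stand-ins for the real model tensors (random values for illustration).
encoder_hidden = torch.randn(batch_size, src_len, hidden_dim)  # [64, 19, 256]
raw_scores = torch.randn(batch_size, src_len, 1)               # [64, 19, 1]

# Normalize over the source-length dimension so each sentence's weights sum to 1.
attention = torch.softmax(raw_scores, dim=1)                   # [64, 19, 1]

# Option 1: broadcast multiply, then sum over the sequence dimension.
# [64, 19, 1] * [64, 19, 256] -> [64, 19, 256] -> sum(dim=1) -> [64, 256]
context_sum = torch.sum(attention * encoder_hidden, dim=1)

# Option 2: batched matrix multiply.
# attention.transpose(1, 2) is [64, 1, 19]; bmm with [64, 19, 256] gives [64, 1, 256].
context_bmm = torch.bmm(attention.transpose(1, 2), encoder_hidden).squeeze(1)

print(context_sum.shape)                                   # torch.Size([64, 256])
print(torch.allclose(context_sum, context_bmm, atol=1e-5))  # True
```

The bmm form makes the linear algebra explicit: for each batch element it multiplies a [1, 19] row of weights by a [19, 256] matrix of hidden states. Both give the same result; the broadcast version just avoids the transpose and squeeze.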