Tags: deep-learning, pytorch, seq2seq

Why do we do a batch matrix-matrix product?


I'm following the PyTorch seq2seq tutorial, and in it the torch.bmm method is used like below:

attn_applied = torch.bmm(attn_weights.unsqueeze(0),
                         encoder_outputs.unsqueeze(0))

I understand why we need to multiply the attention weights by the encoder outputs.

What I don't quite understand is why we need the bmm method here. The torch.bmm documentation says:

Performs a batch matrix-matrix product of matrices stored in batch1 and batch2.

batch1 and batch2 must be 3-D tensors each containing the same number of matrices.

If batch1 is a (b×n×m) tensor, batch2 is a (b×m×p) tensor, out will be a (b×n×p) tensor.
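
The shape rule quoted above can be checked with a short sketch. The sizes b, n, m, p below are arbitrary values picked for illustration, and the loop over torch.mm is only there to show that bmm multiplies each pair of matrices in the batch independently.

```python
import torch

torch.manual_seed(0)

# Hypothetical sizes, chosen only for illustration.
b, n, m, p = 3, 2, 5, 4

batch1 = torch.randn(b, n, m)
batch2 = torch.randn(b, m, p)

# Batched product: one (n x m) @ (m x p) product per batch entry.
out = torch.bmm(batch1, batch2)

# The same computation done one pair of matrices at a time.
looped = torch.stack([torch.mm(batch1[i], batch2[i]) for i in range(b)])

print(out.shape)  # torch.Size([3, 2, 4])
print(torch.allclose(out, looped, atol=1e-5))
```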



Solution

  • In a seq2seq model, the encoder encodes the input sequences, which are supplied as mini-batches. Say, for example, the input is B x S x d, where B is the batch size, S is the maximum sequence length, and d is the word embedding dimension. Then the encoder's output is B x S x h, where h is the hidden state size of the encoder (which is an RNN).

    While decoding (during training), the inputs are fed one time step at a time, so the decoder input is B x 1 x d and the decoder produces a tensor of shape B x 1 x h. To compute the context vector, we need to compare this decoder hidden state with the encoder's encoded states.

    So, consider two tensors of shape T1 = B x S x h and T2 = B x 1 x h. You can do a batch matrix multiplication as follows:

    out = torch.bmm(T1, T2.transpose(1, 2))
    

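    Essentially, you are multiplying a tensor of shape B x S x h by a tensor of shape B x h x 1, and the result is B x S x 1: one attention weight per encoder position for each example in the batch. bmm performs this S x h by h x 1 product independently for each of the B examples, which is why a batched product is needed.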

    Here, the attention weights B x S x 1 represent a similarity score between the decoder's current hidden state and each of the encoder's hidden states. You can then multiply the encoder's hidden states by the attention weights: transpose the encoder states from B x S x h to B x h x S first, and the batch product with B x S x 1 is a tensor of shape B x h x 1. If you squeeze dim=2, you get a tensor of shape B x h, which is your context vector.

    This context vector (B x h) is usually concatenated with the decoder's hidden state (B x 1 x h, squeezed at dim=1 to B x h) to predict the next token.
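
The steps above can be put together in a minimal sketch. The sizes B, S, h and the random tensors are placeholders rather than values from the tutorial, and the softmax over the sequence dimension is an assumption added here (the answer does not say how the scores are normalized).

```python
import torch

torch.manual_seed(0)

# Hypothetical sizes, chosen only for illustration.
B, S, h = 4, 7, 16

encoder_outputs = torch.randn(B, S, h)  # T1: all encoder hidden states
decoder_hidden = torch.randn(B, 1, h)   # T2: decoder state at the current step

# Scores: (B x S x h) @ (B x h x 1) -> (B x S x 1)
scores = torch.bmm(encoder_outputs, decoder_hidden.transpose(1, 2))

# Assumption: normalize the scores over the S encoder positions.
attn_weights = torch.softmax(scores, dim=1)

# Context: (B x h x S) @ (B x S x 1) -> (B x h x 1), then squeeze to (B x h)
context = torch.bmm(encoder_outputs.transpose(1, 2), attn_weights).squeeze(2)

# Concatenate with the decoder state: (B x h) and (B x h) -> (B x 2h)
combined = torch.cat([context, decoder_hidden.squeeze(1)], dim=1)

print(scores.shape, context.shape, combined.shape)
```

In the snippet from the question, the two unsqueeze(0) calls add a leading dimension of size 1 to each tensor, so torch.bmm there multiplies a single pair of matrices; it is the B = 1 case of the same pattern.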