Tags: python, machine-learning, pytorch, transformer-model, attention-model

Query padding mask and key padding mask in Transformer encoder


I'm implementing the self-attention part of a Transformer encoder using PyTorch's nn.MultiheadAttention, and I'm confused about how padding masking works in the Transformer.

The following picture shows the self-attention weights, with queries as rows and keys as columns.

As you can see, there are some "<PAD>" tokens, and I have already masked them in the keys, so those tokens receive no attention weight.

[Figure: self-attention weight matrix with the "<PAD>" key columns masked]

There are still two questions:

  1. In the query part, can I also mask the "<PAD>" tokens (everything except the red square part)? Is this reasonable?

  2. How can I mask "<PAD>" in the query?

The attention weights are computed with a softmax along each row, after applying the mask given in the src_mask or src_key_padding_mask argument. If I set an entire "<PAD>" row to -inf, the softmax returns nan and the loss becomes nan.
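A minimal sketch of what I mean (the shapes, values, and variable names are just illustrative, not my real model): masking only the keys works, but setting a whole query row to -inf produces nan.

```python
import torch
import torch.nn as nn

embed_dim, num_heads, seq_len, batch = 8, 2, 5, 1
mha = nn.MultiheadAttention(embed_dim, num_heads)

x = torch.randn(seq_len, batch, embed_dim)  # (L, N, E) layout (batch_first=False)
# True marks <PAD> positions; the last two tokens are padding.
key_padding_mask = torch.tensor([[False, False, False, True, True]])  # (N, L)

# Masking only the keys: the <PAD> columns get zero attention weight.
out, attn = mha(x, x, x, key_padding_mask=key_padding_mask)
print(attn[0, :, 3:].abs().sum())  # ~0: no weight on <PAD> keys

# Also masking the <PAD> *queries* by setting whole rows of the attention
# mask to -inf makes the softmax of those rows return nan.
attn_mask = torch.zeros(seq_len, seq_len)
attn_mask[3:, :] = float('-inf')   # whole rows for <PAD> queries
out, attn = mha(x, x, x, attn_mask=attn_mask)
print(torch.isnan(attn).any())     # True -> the loss becomes nan
```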


Solution

  • There is no need to mask the queries during self-attention. It is enough not to use the states corresponding to the <PAD> tokens later in the network (either as hidden states or as keys/values); then they will not influence the loss function or anything else in the network.

    If you want to make sure you did not introduce a bug that lets gradients flow through the <PAD> tokens, you can explicitly zero out the self-attention outputs at those positions with torch.where after they are computed; a sketch follows below.
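For illustration, a rough sketch of that safeguard (tensor shapes and variable names are assumptions, not taken from the question): mask only the keys with key_padding_mask, then zero the outputs at the <PAD> query positions with torch.where.

```python
import torch
import torch.nn as nn

embed_dim, num_heads, seq_len, batch = 8, 2, 5, 2
mha = nn.MultiheadAttention(embed_dim, num_heads)

x = torch.randn(seq_len, batch, embed_dim)                    # (L, N, E)
pad_mask = torch.tensor([[False, False, False, True, True],   # True = <PAD>
                         [False, False, True,  True, True]])  # (N, L)

# Mask keys only; query rows for <PAD> still produce finite (but meaningless) states.
out, _ = mha(x, x, x, key_padding_mask=pad_mask)

# Explicitly zero the hidden states of <PAD> queries after attention is computed,
# so no gradient can flow through them even if they are touched by mistake later.
query_pad = pad_mask.transpose(0, 1).unsqueeze(-1)            # (L, N, 1), broadcasts over E
out = torch.where(query_pad, torch.zeros_like(out), out)
```

In practice the same padding mask is then reused downstream (for example, to exclude the <PAD> positions from pooling or from the loss), so the zeroed states never contribute to anything.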