python, deep-learning, pytorch, transformer-model, attention-model

What is the difference between attn_mask and key_padding_mask in MultiheadAttention?


What is the difference between attn_mask and key_padding_mask in PyTorch's MultiheadAttention? The documentation says:

key_padding_mask – if provided, specified padding elements in the key will be ignored by the attention. When given a binary mask and a value is True, the corresponding value on the attention layer will be ignored. When given a byte mask and a value is non-zero, the corresponding value on the attention layer will be ignored.

attn_mask – 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all the batches while a 3D mask allows to specify a different mask for the entries of each batch.

Thanks in advance.


Solution

  • The key_padding_mask is used to mask out positions that are padding, i.e., positions after the end of the input sequence. This is always specific to the input batch and depends on how long the sequences in the batch are compared to the longest one. It is a 2D tensor of shape batch size × input length.

    The attn_mask, on the other hand, says which query–key pairs are valid. In a Transformer decoder, a triangular mask is used to simulate inference time and prevent attending to "future" positions; this is what attn_mask is usually used for. If it is a 2D tensor, the shape is input length × input length. You can also have a mask that is specific to every item in a batch; in that case, you use a 3D tensor of shape (batch size × num heads) × input length × input length. (So, in theory, you can simulate key_padding_mask with a 3D attn_mask.) A small sketch combining both masks follows below.
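
Here is a minimal sketch of both masks passed to nn.MultiheadAttention in self-attention; the batch size, sequence length, embedding size, number of heads, and padding layout are made-up values chosen only for illustration:

    import torch
    import torch.nn as nn

    torch.manual_seed(0)
    batch_size, seq_len, embed_dim, num_heads = 2, 4, 8, 2
    mha = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)

    x = torch.randn(batch_size, seq_len, embed_dim)

    # key_padding_mask: shape (batch_size, seq_len), True = this key position is padding.
    # Here the second sequence has only 2 real tokens; the last 2 positions are padding.
    key_padding_mask = torch.tensor([
        [False, False, False, False],
        [False, False, True,  True],
    ])

    # attn_mask: shape (seq_len, seq_len), causal (upper-triangular) mask shared by the
    # whole batch; True = this query may not attend to this ("future") key position.
    attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)

    out, attn_weights = mha(x, x, x,
                            key_padding_mask=key_padding_mask,
                            attn_mask=attn_mask)
    print(out.shape)           # torch.Size([2, 4, 8])
    print(attn_weights.shape)  # torch.Size([2, 4, 4]), averaged over heads by default

Note that with boolean masks, True always means "do not attend to this position" for both arguments; the two masks are combined, so a key is ignored if either mask blocks it.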