
TransformerEncoder with a padding mask


I'm trying to use torch.nn.TransformerEncoder with a src_key_padding_mask that is not None. Suppose the input has shape src = [20, 95] and the binary padding mask has shape src_mask = [20, 95], with 1 at the positions of padded tokens and 0 elsewhere. I build a transformer encoder with 6 layers, each of which contains an attention block with 8 heads and a hidden dimension of 256:

import torch

layer = torch.nn.TransformerEncoderLayer(256, 8, 256, 0.1)
encoder = torch.nn.TransformerEncoder(layer, 6)
embed = torch.nn.Embedding(80000, 256)
src = torch.randint(0, 1000, (20, 95))
src = embed(src)  # shape [20, 95, 256]
src_mask = torch.randint(0, 2, (20, 95))
output = encoder(src, src_mask)  # raises the RuntimeError below

But I get the following error:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-107-31bf7ab8384b> in <module>
----> 1 output =  encoder(src, src_mask)

~/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    545             result = self._slow_forward(*input, **kwargs)
    546         else:
--> 547             result = self.forward(*input, **kwargs)
    548         for hook in self._forward_hooks.values():
    549             hook_result = hook(self, input, result)

~/anaconda3/lib/python3.7/site-packages/torch/nn/modules/transformer.py in forward(self, src, mask, src_key_padding_mask)
    165         for i in range(self.num_layers):
    166             output = self.layers[i](output, src_mask=mask,
--> 167                                     src_key_padding_mask=src_key_padding_mask)
    168 
    169         if self.norm:

~/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    545             result = self._slow_forward(*input, **kwargs)
    546         else:
--> 547             result = self.forward(*input, **kwargs)
    548         for hook in self._forward_hooks.values():
    549             hook_result = hook(self, input, result)

~/anaconda3/lib/python3.7/site-packages/torch/nn/modules/transformer.py in forward(self, src, src_mask, src_key_padding_mask)
    264         """
    265         src2 = self.self_attn(src, src, src, attn_mask=src_mask,
--> 266                               key_padding_mask=src_key_padding_mask)[0]
    267         src = src + self.dropout1(src2)
    268         src = self.norm1(src)

~/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    545             result = self._slow_forward(*input, **kwargs)
    546         else:
--> 547             result = self.forward(*input, **kwargs)
    548         for hook in self._forward_hooks.values():
    549             hook_result = hook(self, input, result)

~/anaconda3/lib/python3.7/site-packages/torch/nn/modules/activation.py in forward(self, query, key, value, key_padding_mask, need_weights, attn_mask)
    781                 training=self.training,
    782                 key_padding_mask=key_padding_mask, need_weights=need_weights,
--> 783                 attn_mask=attn_mask)
    784 
    785 

~/anaconda3/lib/python3.7/site-packages/torch/nn/functional.py in multi_head_attention_forward(query, key, value, embed_dim_to_check, num_heads, in_proj_weight, in_proj_bias, bias_k, bias_v, add_zero_attn, dropout_p, out_proj_weight, out_proj_bias, training, key_padding_mask, need_weights, attn_mask, use_separate_proj_weight, q_proj_weight, k_proj_weight, v_proj_weight, static_k, static_v)
   3250     if attn_mask is not None:
   3251         attn_mask = attn_mask.unsqueeze(0)
-> 3252         attn_output_weights += attn_mask
   3253 
   3254     if key_padding_mask is not None:

RuntimeError: The size of tensor a (20) must match the size of tensor b (95) at non-singleton dimension 2

I was wondering if somebody could help me figure out this problem.

Thanks


Solution

  • The required shapes are listed in the Shape section of the nn.Transformer.forward documentation (all building blocks of the transformer refer to it). The relevant ones for the encoder are:

    • src: (S, N, E)
    • src_mask: (S, S)
    • src_key_padding_mask: (N, S)

    where S is the sequence length, N the batch size and E the embedding dimension (number of features).

    After the embedding, src has shape [20, 95, 256], which the encoder reads as S = 20 and N = 95. The padding mask should therefore have shape [95, 20], not [20, 95]. This assumes that your batch size is 95 and the sequence length is 20; if it is the other way around, you would have to transpose src instead.

    Furthermore, when calling the encoder, you are not specifying the src_key_padding_mask, but rather the src_mask, as the signature of torch.nn.TransformerEncoder.forward is:

    forward(src, mask=None, src_key_padding_mask=None)
    

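    The padding mask must be passed as the keyword argument src_key_padding_mask, not as the second positional argument. The second positional argument is mask, the attention mask, which is expected to have shape (S, S) = (20, 20); passing a [20, 95] tensor there is what produces the size mismatch in the traceback. To avoid confusion, your src_mask should be renamed to src_key_padding_mask.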

    # Shape (N, S) = (95, 20); converted to bool so that True marks padded positions
    src_key_padding_mask = torch.randint(0, 2, (95, 20)).bool()
    output = encoder(src, src_key_padding_mask=src_key_padding_mask)
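In practice the padding mask usually comes from the sequence lengths rather than from random values. The sketch below puts the shapes above together under a few assumptions that are not in the question: the `lengths` tensor is made up for illustration, the vocabulary is shrunk to 1000 to keep the example light, and the default sequence-first layout (batch_first=False) is used.

```python
import torch

torch.manual_seed(0)

S, N, E = 20, 95, 256  # sequence length, batch size, embedding dimension

layer = torch.nn.TransformerEncoderLayer(d_model=E, nhead=8, dim_feedforward=256, dropout=0.1)
encoder = torch.nn.TransformerEncoder(layer, num_layers=6)
embed = torch.nn.Embedding(1000, E)  # smaller vocabulary than in the question (assumption)

# Hypothetical true lengths, one per sequence, each with at least one real token.
lengths = torch.randint(1, S + 1, (N,))

tokens = torch.randint(0, 1000, (S, N))  # (S, N): sequence-first layout
src = embed(tokens)                      # (S, N, E)

# Position index >= length means the position is padding.
positions = torch.arange(S).unsqueeze(0)                   # (1, S)
src_key_padding_mask = positions >= lengths.unsqueeze(1)   # (N, S), dtype bool

encoder.eval()  # disable dropout so the result is deterministic
with torch.no_grad():
    output = encoder(src, src_key_padding_mask=src_key_padding_mask)

print(output.shape)  # torch.Size([20, 95, 256])
```

The output keeps the (S, N, E) layout of src. Because every sequence has at least one unmasked position, no attention row is fully masked, which would otherwise tend to produce NaN values.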