I'm trying to use torch.nn.TransformerEncoder with a src_key_padding_mask that is not None. Suppose the input has shape src = [20, 95]
and the binary padding mask has shape src_mask = [20, 95]
, with 1 at the positions of padded tokens and 0 elsewhere. I build a transformer encoder with 6 layers, each containing an attention with 8 heads and a hidden dimension of 256:
import torch

layer = torch.nn.TransformerEncoderLayer(256, 8, 256, 0.1)
encoder = torch.nn.TransformerEncoder(layer, 6)
embed = torch.nn.Embedding(80000, 256)
src = torch.randint(0, 1000, (20, 95))
src = embed(src)
src_mask = torch.randint(0, 2, (20, 95))
output = encoder(src, src_mask)
But I get the following error:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-107-31bf7ab8384b> in <module>
----> 1 output = encoder(src, src_mask)
~/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
545 result = self._slow_forward(*input, **kwargs)
546 else:
--> 547 result = self.forward(*input, **kwargs)
548 for hook in self._forward_hooks.values():
549 hook_result = hook(self, input, result)
~/anaconda3/lib/python3.7/site-packages/torch/nn/modules/transformer.py in forward(self, src, mask, src_key_padding_mask)
165 for i in range(self.num_layers):
166 output = self.layers[i](output, src_mask=mask,
--> 167 src_key_padding_mask=src_key_padding_mask)
168
169 if self.norm:
~/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
545 result = self._slow_forward(*input, **kwargs)
546 else:
--> 547 result = self.forward(*input, **kwargs)
548 for hook in self._forward_hooks.values():
549 hook_result = hook(self, input, result)
~/anaconda3/lib/python3.7/site-packages/torch/nn/modules/transformer.py in forward(self, src, src_mask, src_key_padding_mask)
264 """
265 src2 = self.self_attn(src, src, src, attn_mask=src_mask,
--> 266 key_padding_mask=src_key_padding_mask)[0]
267 src = src + self.dropout1(src2)
268 src = self.norm1(src)
~/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
545 result = self._slow_forward(*input, **kwargs)
546 else:
--> 547 result = self.forward(*input, **kwargs)
548 for hook in self._forward_hooks.values():
549 hook_result = hook(self, input, result)
~/anaconda3/lib/python3.7/site-packages/torch/nn/modules/activation.py in forward(self, query, key, value, key_padding_mask, need_weights, attn_mask)
781 training=self.training,
782 key_padding_mask=key_padding_mask, need_weights=need_weights,
--> 783 attn_mask=attn_mask)
784
785
~/anaconda3/lib/python3.7/site-packages/torch/nn/functional.py in multi_head_attention_forward(query, key, value, embed_dim_to_check, num_heads, in_proj_weight, in_proj_bias, bias_k, bias_v, add_zero_attn, dropout_p, out_proj_weight, out_proj_bias, training, key_padding_mask, need_weights, attn_mask, use_separate_proj_weight, q_proj_weight, k_proj_weight, v_proj_weight, static_k, static_v)
3250 if attn_mask is not None:
3251 attn_mask = attn_mask.unsqueeze(0)
-> 3252 attn_output_weights += attn_mask
3253
3254 if key_padding_mask is not None:
RuntimeError: The size of tensor a (20) must match the size of tensor b (95) at non-singleton dimension 2
I was wondering if somebody could help me figure out this problem.
Thanks
The required shapes are shown in nn.Transformer.forward - Shape (all building blocks of the transformer refer to it). The relevant ones for the encoder are:
- src: (S, N, E)
- src_key_padding_mask: (N, S)
where S is the sequence length, N the batch size and E the embedding dimension (number of features).
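As a quick sanity check of those shapes with the numbers from the question, here is a minimal sketch that treats the [20, 95] input as (S, N), i.e. sequence length 20 and batch size 95 (the dimension names are just for illustration):

import torch

S, N, E = 20, 95, 256                   # sequence length, batch size, embedding dimension

src = torch.rand(S, N, E)               # embedded input, shape (S, N, E) = [20, 95, 256]
src_key_padding_mask = torch.zeros(N, S, dtype=torch.bool)  # shape (N, S) = [95, 20]

print(src.shape)                        # torch.Size([20, 95, 256])
print(src_key_padding_mask.shape)       # torch.Size([95, 20])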
The padding mask should have shape [95, 20], not [20, 95]. That assumes your batch size is 95 and your sequence length is 20; if it is the other way around, you would have to transpose the src instead (see the sketch below).
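For that second case, here is a minimal sketch, assuming your embedded src comes out batch-first as [20, 95, 256] (batch size 20, sequence length 95); the tensors here are placeholders:

import torch

# Hypothetical batch-first embedding output: [batch=20, seq=95, emb=256]
src = torch.rand(20, 95, 256)

# The encoder expects (S, N, E), so swap the first two dimensions
src = src.transpose(0, 1)               # now [95, 20, 256] = (S, N, E)

# In this interpretation the original [20, 95] mask already has the required (N, S) layout
src_key_padding_mask = torch.zeros(20, 95, dtype=torch.bool)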
Furthermore, when calling the encoder you are not specifying the src_key_padding_mask, but rather the src_mask, as the signature of torch.nn.TransformerEncoder.forward is:
forward(src, mask=None, src_key_padding_mask=None)
The padding mask must be specified as the keyword argument src_key_padding_mask, not as the second positional argument. And to avoid confusion, your src_mask should be renamed to src_key_padding_mask.
src_key_padding_mask = torch.randint(0, 2, (95, 20)).bool()  # True marks padded positions
output = encoder(src, src_key_padding_mask=src_key_padding_mask)
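Putting it all together, here is a runnable sketch of the whole setup with the corrected mask shape and call. The random mask values and the cast to bool are just for illustration; in practice the mask would come from your actual padding, with True marking padded positions:

import torch

layer = torch.nn.TransformerEncoderLayer(256, 8, 256, 0.1)   # d_model=256, 8 heads
encoder = torch.nn.TransformerEncoder(layer, 6)               # 6 stacked layers
embed = torch.nn.Embedding(80000, 256)

src = embed(torch.randint(0, 1000, (20, 95)))                 # (S, N, E) = [20, 95, 256]

# Padding mask in (N, S) layout; True marks positions the attention should ignore
src_key_padding_mask = torch.randint(0, 2, (95, 20)).bool()

output = encoder(src, src_key_padding_mask=src_key_padding_mask)
print(output.shape)                                           # torch.Size([20, 95, 256])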