Tags: pytorch, tensor, transformer-model, attention-model, huggingface-transformers

Why is the input size of the MultiheadAttention in Pytorch Transformer module 1536?


When using the torch.nn.modules.transformer.Transformer module, the first layer is encoder.layers.0.self_attn, which is a MultiheadAttention layer:

from torch.nn.modules.transformer import Transformer
bumblebee = Transformer()

bumblebee.parameters

[out]:

<bound method Module.parameters of Transformer(
  (encoder): TransformerEncoder(
    (layers): ModuleList(
      (0): TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): Linear(in_features=512, out_features=512, bias=True)
        )
        (linear1): Linear(in_features=512, out_features=2048, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=2048, out_features=512, bias=True)
        (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )

And if we print out the shapes of the encoder's parameters, we see:

for name, tensor in bumblebee.encoder.state_dict().items():
    print(name, '\t', tensor.shape)

[out]:

layers.0.self_attn.in_proj_weight    torch.Size([1536, 512])
layers.0.self_attn.in_proj_bias      torch.Size([1536])
layers.0.self_attn.out_proj.weight   torch.Size([512, 512])
layers.0.self_attn.out_proj.bias     torch.Size([512])
layers.0.linear1.weight      torch.Size([2048, 512])
layers.0.linear1.bias    torch.Size([2048])
layers.0.linear2.weight      torch.Size([512, 2048])
layers.0.linear2.bias    torch.Size([512])
layers.0.norm1.weight    torch.Size([512])
layers.0.norm1.bias      torch.Size([512])
layers.0.norm2.weight    torch.Size([512])
layers.0.norm2.bias      torch.Size([512])
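As a rough cross-check of these shapes, the sketch below counts the parameters of a single encoder layer built with what are assumed to be nn.Transformer's defaults (d_model=512, nhead=8, dim_feedforward=2048). It shows that the 1536-row in_proj holds exactly as many parameters as three separate 512 -> 512 Linear layers would.

```python
from torch import nn

# One encoder layer with (assumed) nn.Transformer defaults
layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, dim_feedforward=2048)
attn = layer.self_attn

# Packed Q/K/V projection: 1536*512 weights + 1536 biases
packed = attn.in_proj_weight.numel() + attn.in_proj_bias.numel()

# Three independent 512 -> 512 Linear layers (weights + biases), for comparison
separate = sum(p.numel() for _ in range(3) for p in nn.Linear(512, 512).parameters())

# Every parameter in the layer: attention, feed-forward and the two LayerNorms
total = sum(p.numel() for p in layer.parameters())

print(packed, separate)  # 787968 787968
print(total)             # 3152384
```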

It seems that 1536 = 3 * 512, so the layers.0.self_attn.in_proj_weight parameter might be storing the query, key, and value (QKV) projection matrices of the transformer architecture stacked together in one matrix.
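One way to test that hypothesis is a small sketch, assuming a recent PyTorch: split in_proj_weight into three 512 x 512 blocks (PyTorch unpacks the packed weight in query, key, value order) and check that a single matmul against the packed weight gives the same Q, K, V as three separate projections. The random input `x` is purely illustrative.

```python
import torch
import torch.nn.functional as F
from torch import nn

torch.manual_seed(0)
# Same sizes nn.Transformer uses by default: d_model=512, nhead=8
mha = nn.MultiheadAttention(embed_dim=512, num_heads=8)

# in_proj_weight stacks the query, key and value projections along dim 0
print(mha.in_proj_weight.shape)  # torch.Size([1536, 512])
w_q, w_k, w_v = mha.in_proj_weight.chunk(3, dim=0)
b_q, b_k, b_v = mha.in_proj_bias.chunk(3, dim=0)
print(w_q.shape, w_k.shape, w_v.shape)

# One matmul against the packed weight, split afterwards, should match
# three separate 512 -> 512 projections (up to float rounding)
x = torch.randn(10, 512)  # 10 illustrative token embeddings
q, k, v = F.linear(x, mha.in_proj_weight, mha.in_proj_bias).chunk(3, dim=-1)
print(torch.allclose(q, F.linear(x, w_q, b_q), atol=1e-5),
      torch.allclose(k, F.linear(x, w_k, b_k), atol=1e-5),
      torch.allclose(v, F.linear(x, w_v, b_v), atol=1e-5))
```

Packing the three projections this way lets the module compute Q, K and V for self-attention with one matrix multiplication instead of three.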

From https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/activation.py#L649

class MultiheadAttention(Module):
    def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None):
        super(MultiheadAttention, self).__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim

        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"

        if self._qkv_same_embed_dim is False:
            self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
            self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim))
            self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim))
        else:
            self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim))

And a note in the MultiheadAttention docstring says:

Note: if kdim and vdim are None, they will be set to embed_dim such that query, key, and value have the same number of features.
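The two branches of that constructor can be observed directly. This is a sketch assuming a recent PyTorch; `_qkv_same_embed_dim` is a private attribute read here only for illustration, and kdim=256, vdim=128 are arbitrary example values. When kdim or vdim differ from embed_dim, in_proj_weight is registered as None and three separate projection matrices appear instead (the bias stays packed).

```python
from torch import nn

# kdim/vdim left as None -> they default to embed_dim -> one packed weight
same = nn.MultiheadAttention(embed_dim=512, num_heads=8)
print(same._qkv_same_embed_dim)   # True
print(same.in_proj_weight.shape)  # torch.Size([1536, 512])

# Key/value feature sizes differ from embed_dim -> separate projection matrices
diff = nn.MultiheadAttention(embed_dim=512, num_heads=8, kdim=256, vdim=128)
print(diff._qkv_same_embed_dim)   # False
print(diff.in_proj_weight)        # None
print(diff.q_proj_weight.shape)   # torch.Size([512, 512])
print(diff.k_proj_weight.shape)   # torch.Size([512, 256])
print(diff.v_proj_weight.shape)   # torch.Size([512, 128])
print(diff.in_proj_bias.shape)    # torch.Size([1536])
```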

Is that correct?


Solution

  • From the nn.Transformer definition with the default values, TransformerEncoderLayer is instantiated with d_model=512 and nhead=8.

    The MultiheadAttention layer is instantiated with embed_dim=d_model and num_heads=nhead, while kdim and vdim are left at their default value of None.

    If they are None, they are set to embed_dim, so self._qkv_same_embed_dim evaluates to True. In that case, as you correctly pointed out, self.in_proj_weight is defined as a single Parameter of shape (3 * embed_dim, embed_dim), i.e. (1536, 512).

    In short: yes, that's correct.
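To see how the packed weight is actually consumed, here is a sketch that recomputes a self-attention forward pass by hand and compares it with the module's output. It assumes a recent PyTorch (for `batch_first=True`), the default dropout of 0, and no attention masks; the helper `split_heads` is a hypothetical name used only for illustration.

```python
import math
import torch
import torch.nn.functional as F
from torch import nn

torch.manual_seed(0)
E, H = 512, 8   # nn.Transformer defaults: d_model=512, nhead=8
D = E // H      # head_dim = 64
mha = nn.MultiheadAttention(E, H, batch_first=True)

x = torch.randn(2, 5, E)  # (batch, seq_len, embed_dim); self-attention: q = k = v = x

# 1) One matmul against the packed (3*E, E) weight, then split into Q, K, V
q, k, v = F.linear(x, mha.in_proj_weight, mha.in_proj_bias).chunk(3, dim=-1)

# 2) Split the E features into H heads of size D: (batch, H, seq_len, D)
def split_heads(t):
    return t.reshape(t.size(0), t.size(1), H, D).transpose(1, 2)

q, k, v = split_heads(q), split_heads(k), split_heads(v)

# 3) Scaled dot-product attention, computed independently per head
attn = torch.softmax(q @ k.transpose(-2, -1) / math.sqrt(D), dim=-1)
heads = attn @ v  # (batch, H, seq_len, D)

# 4) Concatenate the heads and apply the output projection (out_proj: E -> E)
concat = heads.transpose(1, 2).reshape(x.size(0), x.size(1), E)
manual = mha.out_proj(concat)

expected, _ = mha(x, x, x)
print(manual.shape, torch.allclose(manual, expected, atol=1e-5))
```

Only out_proj is a real submodule; the QKV projection lives in a bare Parameter, which is why the printed module structure in the question shows out_proj but not in_proj_weight.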