python-3.x, pytorch, transformer-model, attention-model

Why must the embedding dimension be divisible by the number of heads in MultiheadAttention?


I am learning about the Transformer. Here is the PyTorch documentation for MultiheadAttention. In their implementation, I saw this constraint:

 assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
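For example (my own quick check, not taken from the documentation), constructing the module with values that violate this constraint fails:

    import torch.nn as nn

    # 512 / 8 = 64, so this is accepted
    ok = nn.MultiheadAttention(embed_dim=512, num_heads=8)

    # 512 is not divisible by 10, so this trips the assert above
    # (the exact error type may differ between PyTorch versions)
    bad = nn.MultiheadAttention(embed_dim=512, num_heads=10)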

Why require the constraint that embed_dim must be divisible by num_heads? If we go back to the equation

MultiHead(Q, K, V) = Concat(head_1, ..., head_h) W^O, where head_i = Attention(Q W_i^Q, K W_i^K, V W_i^V)

Assume Q, K, V are n x embed_dim matrices and all the weight matrices W_i are embed_dim x head_dim.

Then the concatenation [head_1, ..., head_h] will be an n x (num_heads * head_dim) matrix;

W^O has size (num_heads * head_dim) x embed_dim;

so [head_1, ..., head_h] * W^O gives an n x embed_dim output.

I don't see why we require embed_dim to be divisible by num_heads.

Say we have num_heads = 10000: the output shape is still the same, since the matrix product with W^O absorbs this choice.
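
Here is a rough sketch of what I mean (my own toy code, not PyTorch's implementation), where head_dim is chosen independently of embed_dim and the shapes still work out:

    import torch

    n, embed_dim = 10, 512           # sequence length, model width
    num_heads, head_dim = 10, 7      # num_heads * head_dim does not need to equal embed_dim

    Q = K = V = torch.rand(n, embed_dim)

    heads = []
    for _ in range(num_heads):
        # per-head projection matrices of size embed_dim x head_dim
        W_q = torch.rand(embed_dim, head_dim)
        W_k = torch.rand(embed_dim, head_dim)
        W_v = torch.rand(embed_dim, head_dim)
        q, k, v = Q @ W_q, K @ W_k, V @ W_v                       # n x head_dim each
        attn = torch.softmax(q @ k.T / head_dim ** 0.5, dim=-1)   # n x n
        heads.append(attn @ v)                                    # n x head_dim

    concat = torch.cat(heads, dim=-1)                  # n x (num_heads * head_dim)
    W_o = torch.rand(num_heads * head_dim, embed_dim)  # output projection
    out = concat @ W_o                                 # n x embed_dim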


Solution

  • From what I understood, it is a simplification they added to keep things simple. Theoretically, we could implement the model the way you propose (similar to the original paper). The PyTorch documentation mentions this briefly (a short sketch of the split is at the end of this answer):

    Note that `embed_dim` will be split across `num_heads` (i.e. each head will have dimension `embed_dim` // `num_heads`)
    

    Also, if you look at the PyTorch implementation, you can see it is a bit different (optimised, in my opinion) compared to the originally proposed model. For example, they use MatMul instead of Linear, and the Concat layer is skipped. Refer to the figure below, which shows the first encoder (with batch size 32, 10 words, 512 features).

    [image: graph of the first encoder layer]

    P.S.: If you need to see the model params (like in the above image), this is the code I used.

    import torch

    # a 1-layer encoder / 1-layer decoder Transformer; change params as necessary
    transformer_model = torch.nn.Transformer(d_model=512, nhead=8, num_encoder_layers=1, num_decoder_layers=1, dim_feedforward=11)
    tgt = torch.rand((20, 32, 512))   # (target length, batch, d_model)
    src = torch.rand((11, 32, 512))   # (source length, batch, d_model)
    # export to ONNX so the graph and its params can be inspected
    torch.onnx.export(transformer_model, (src, tgt), "transformer_model.onnx")
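
    To make the splitting concrete, here is a minimal sketch (my own simplification, not the actual PyTorch source) of how a single embed_dim-wide projection is reshaped into num_heads chunks of size embed_dim // num_heads. This reshape is exactly where the divisibility requirement comes from:

    import torch

    batch, seq_len, embed_dim, num_heads = 32, 10, 512, 8
    head_dim = embed_dim // num_heads   # 64; a clean split only because 512 % 8 == 0

    x = torch.rand(batch, seq_len, embed_dim)
    W_q = torch.rand(embed_dim, embed_dim)           # one projection for all heads at once
    q = x @ W_q                                      # (32, 10, 512)
    q = q.view(batch, seq_len, num_heads, head_dim)  # (32, 10, 8, 64)
    q = q.transpose(1, 2)                            # (32, 8, 10, 64): one 64-dim slice per head

    If embed_dim were not a multiple of num_heads, the view above would fail, which is what the assert in the question guards against.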