
Why the W_q matrix in torch.nn.MultiheadAttention is square


I am trying to implement nn.MultiheadAttention in my network. According to the docs,

embed_dim  – total dimension of the model.

However, according to the source file,

embed_dim must be divisible by num_heads

and

self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))

If I understand correctly, this means that each head only gets a part of each query's features, since the matrix is square. Is this an implementation bug, or is my understanding wrong?
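
The shapes in question can be inspected directly. A minimal check with made-up dimensions (note that recent PyTorch versions fuse the three projections into a single in_proj_weight unless kdim or vdim differ from embed_dim):

    import torch.nn as nn

    # Default case: query, key and value all use embed_dim, so the three
    # square projections are fused into one (3*embed_dim, embed_dim) parameter.
    mha = nn.MultiheadAttention(embed_dim=8, num_heads=2)
    print(mha.in_proj_weight.shape)    # torch.Size([24, 8])

    # With a different vdim the projections stay separate, and q_proj_weight
    # is square (embed_dim x embed_dim) regardless of num_heads.
    mha2 = nn.MultiheadAttention(embed_dim=8, num_heads=2, kdim=8, vdim=4)
    print(mha2.q_proj_weight.shape)    # torch.Size([8, 8])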


Solution

  • Each head uses a different part of the projected query vector. You can think of it as the projected query being split into num_heads vectors that are then used independently to compute scaled dot-product attention. So each head operates on a different linear combination of the features in the queries (and in the keys and values, too). This linear projection is done with the self.q_proj_weight matrix, and the projected queries are passed to the F.multi_head_attention_forward function.

    In F.multi_head_attention_forward, this is implemented by reshaping and transposing the projected queries, so that the independent attentions for the individual heads can be computed efficiently with batched matrix multiplication (the sketch at the end of this answer walks through this reshape).

    Tying the head size to embed_dim / num_heads is a design decision of PyTorch. In theory, you could use a different head size, so that the projection matrix has shape embed_dim × (num_heads · head_dim). Some transformer implementations (such as the C++-based Marian for machine translation, or Hugging Face's Transformers) allow that.
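
    Below is a minimal sketch of that split, with made-up dimensions (embed_dim=8, num_heads=2), a random matrix standing in for self.q_proj_weight, and the same tensor used as both query and key to keep it short:

        import torch
        import torch.nn.functional as F

        # Illustrative sizes, not taken from the question.
        embed_dim, num_heads = 8, 2
        head_dim = embed_dim // num_heads            # PyTorch requires an even split
        seq_len, batch = 5, 1

        q = torch.randn(seq_len, batch, embed_dim)   # (L, N, E), as nn.MultiheadAttention expects
        W_q = torch.randn(embed_dim, embed_dim)      # square, like self.q_proj_weight

        # 1) Project the full query with the (E x E) matrix ...
        q_proj = q @ W_q.T                           # (L, N, E)

        # 2) ... then split the result across heads by reshaping and transposing,
        #    the same way F.multi_head_attention_forward does.
        q_heads = (
            q_proj.contiguous()
            .view(seq_len, batch * num_heads, head_dim)  # (L, N*H, E/H)
            .transpose(0, 1)                             # (N*H, L, E/H)
        )

        # Each head now sees its own head_dim-sized slice of the projected
        # features; all heads are processed at once with one batched matmul.
        scores = torch.bmm(q_heads, q_heads.transpose(1, 2)) / head_dim ** 0.5
        attn = F.softmax(scores, dim=-1)             # (N*H, L, L): one attention map per head
        print(attn.shape)                            # torch.Size([2, 5, 5])

    With a decoupled head size, W_q would instead have shape (num_heads * head_dim, embed_dim), the view would still use head_dim, and the rest would stay the same.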