Tags: neural-network, nlp, pytorch, bert-language-model, attention-model

Multi-Head Attention: Correct implementation of the linear transformations of Q, K, V


I am currently implementing Multi-Head Self-Attention in PyTorch. I looked at a couple of implementations, and they seem slightly wrong, or at least I am not sure why they are done the way they are. They often apply each linear projection just once for all heads:

    self.query_projection = nn.Linear(input_dim, output_dim)
    self.key_projection = nn.Linear(input_dim, output_dim)
    self.value_projection = nn.Linear(input_dim, output_dim)

and then they reshape the projected tensors into per-head chunks:

    query_heads = query_projected.view(batch_size, query_len, head_count, head_dimension).transpose(1, 2)  # (batch_size, head_count, query_len, head_dimension)
    key_heads = key_projected.view(batch_size, key_len, head_count, head_dimension).transpose(1, 2)  # (batch_size, head_count, key_len, head_dimension)
    value_heads = value_projected.view(batch_size, value_len, head_count, head_dimension).transpose(1, 2)  # (batch_size, head_count, value_len, head_dimension)

    attention_weights = scaled_dot_product(query_heads, key_heads) 

According to this code, each head works on a slice of the projected query. However, the original paper says that we need a separate linear projection for each head in the encoder.

Is the implementation shown above correct?


Solution

  • They are equivalent.

    Theoretically (and in the paper's formulation), it is easier to think of them as separate linear projections: if you have 8 heads and each head applies an M->N projection, you have 8 separate N-by-M matrices.

    In an implementation, however, it is faster to perform a single M->8N transformation using one 8N-by-M matrix.

    You can concatenate the 8 matrices of the first formulation along the output dimension to obtain the single matrix of the second formulation, so the two are mathematically identical. A small sketch demonstrating this numerically follows below.
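
    Below is a minimal sketch (not from the original answer) that checks the equivalence numerically in PyTorch. The dimensions, the bias=False setting, and the variable names are illustrative assumptions; the point is that the row-slices of the single large weight matrix are exactly the per-head projection matrices.

        # Illustrative sketch: one big projection vs. separate per-head projections.
        # All shapes and names here are assumptions, not from the original post.
        import torch
        import torch.nn as nn

        torch.manual_seed(0)

        batch_size, seq_len = 2, 5
        input_dim, head_count, head_dimension = 16, 8, 4
        output_dim = head_count * head_dimension  # 8N = 32

        x = torch.randn(batch_size, seq_len, input_dim)

        # Formulation 2: one M -> 8N projection (what the questioned code does).
        big_projection = nn.Linear(input_dim, output_dim, bias=False)
        projected = big_projection(x)  # (batch_size, seq_len, head_count * head_dimension)
        heads_from_big = projected.view(batch_size, seq_len, head_count, head_dimension).transpose(1, 2)
        # heads_from_big: (batch_size, head_count, seq_len, head_dimension)

        # Formulation 1: a separate M -> N projection per head, whose weight is the
        # corresponding row-slice of the big projection matrix.
        per_head_outputs = []
        for h in range(head_count):
            w_h = big_projection.weight[h * head_dimension:(h + 1) * head_dimension]  # (N, M)
            per_head_outputs.append(x @ w_h.T)  # (batch_size, seq_len, head_dimension)
        heads_from_separate = torch.stack(per_head_outputs, dim=1)
        # heads_from_separate: (batch_size, head_count, seq_len, head_dimension)

        # Both formulations produce the same per-head projections.
        print(torch.allclose(heads_from_big, heads_from_separate, atol=1e-6))  # True

    In other words, the single nn.Linear followed by view/transpose is just a batched way of applying all per-head projections at once; the scaled dot-product attention then operates on each head's slice independently.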