I am implementing Multi-Head Self-Attention in PyTorch now. I looked at a couple of implementations, and they seem a bit wrong, or at least I am not sure why they are done the way they are. They often apply the linear projection just once:
self.query_projection = nn.Linear(input_dim, output_dim)
self.key_projection = nn.Linear(input_dim, output_dim)
self.value_projection = nn.Linear(input_dim, output_dim)
and then they would often reshape the projection as
query_heads = query_projected.view(batch_size, query_len, head_count, head_dimension).transpose(1, 2) # (batch_size, heads_count, query_len, d_head)
key_heads = key_projected.view(batch_size, key_len, head_count, head_dimension).transpose(1, 2) # (batch_size, heads_count, key_len, d_head)
value_heads = value_projected.view(batch_size, value_len, head_count, head_dimension).transpose(1, 2) # (batch_size, heads_count, value_len, d_head)
attention_weights = scaled_dot_product(query_heads, key_heads)
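(For reference, the scaled_dot_product helper in those implementations is roughly the following; this is just a sketch of the usual computation, not any specific repository's code:)
import math
import torch
import torch.nn.functional as F

def scaled_dot_product(query_heads, key_heads):
    # query_heads: (batch_size, heads_count, query_len, d_head)
    # key_heads:   (batch_size, heads_count, key_len, d_head)
    d_head = query_heads.size(-1)
    scores = torch.matmul(query_heads, key_heads.transpose(-2, -1)) / math.sqrt(d_head)
    return F.softmax(scores, dim=-1)  # (batch_size, heads_count, query_len, key_len)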
According to this code, each head works on a slice of the projected query. However, the original paper says that we need a different linear projection for each head in the encoder.
Is this implementation correct?
They are equivalent.
Theoretically (and in paper writing), it is easier to consider them as separate linear projections: say you have 8 heads, and each head has an M -> N projection, then you have 8 separate N-by-M matrices.
In the implementation, though, it is faster to apply a single M -> 8N transformation, i.e. one 8N-by-M matrix.
One can concatenate the 8 per-head matrices of the first formulation along the output dimension to obtain the single matrix of the second formulation, so the two compute exactly the same projections; the .view(...) reshape shown in the question then slices the combined output back into the individual heads.
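Here is a minimal sketch that checks this numerically; the variable names and sizes are illustrative, not taken from any particular implementation:
import torch
import torch.nn as nn

torch.manual_seed(0)
M, N, head_count = 16, 4, 8                # per-head projection: M -> N, with 8 heads
x = torch.randn(2, 10, M)                  # (batch_size, seq_len, M)

# First formulation: one M -> N projection per head
per_head = [nn.Linear(M, N, bias=False) for _ in range(head_count)]
out_separate = torch.cat([layer(x) for layer in per_head], dim=-1)  # (2, 10, 8N)

# Second formulation: a single M -> 8N projection whose (8N x M) weight matrix
# is the concatenation of the per-head (N x M) weight matrices
combined = nn.Linear(M, head_count * N, bias=False)
with torch.no_grad():
    combined.weight.copy_(torch.cat([layer.weight for layer in per_head], dim=0))
out_combined = combined(x)

print(torch.allclose(out_separate, out_combined, atol=1e-6))  # True
The subsequent .view(batch_size, seq_len, head_count, head_dimension) splits that combined 8N-dimensional output into 8 contiguous chunks of size N, which are exactly the outputs the 8 separate projections would have produced.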