I have seen two different implementations of Multi-Head Attention. The first splits the queries, keys, and values into heads before applying the linear layers:
def split_heads(self, x, batch_size):
    return x.reshape(batch_size, -1, self.heads, self.projection_dim)

def forward(self, queries, keys, values, mask):
    batch_size = queries.size(0)
    # split queries, keys, and values into heads
    queries = self.split_heads(queries, batch_size)
    keys = self.split_heads(keys, batch_size)
    values = self.split_heads(values, batch_size)
    # per-head linear projections, applied after the split
    queries = self.queries_linear(queries)
    keys = self.keys_linear(keys)
    values = self.values_linear(values)
    # ...more code
The second applies the linear layers first and then splits the result into heads:

def forward(self, queries, keys, values, mask=None):
    batch_size = queries.size(0)
    # perform the linear operation first, then split into h heads
    k = self.keys_linear(keys).view(batch_size, -1, self.heads, self.projection_dim)
    q = self.queries_linear(queries).view(batch_size, -1, self.heads, self.projection_dim)
    v = self.values_linear(values).view(batch_size, -1, self.heads, self.projection_dim)
    # ...more code
From what I can deduce from the figure in the paper Attention Is All You Need, the queries and keys should be split into heads before being passed through the linear layers, but in most implementations online they are split afterwards.
Are the two approaches equivalent, or is one better than the other?
Yes, I think the figure in the paper is quite confusing.
But according to the PyTorch Lightning tutorial, the order is linear first, then split. I think this way the linear layer helps decide how to distribute the values across the heads.
https://lightning.ai/docs/pytorch/stable/notebooks/course_UvA-DL/05-transformers-and-MH-attention.html
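To make the ordering concrete, here is a minimal sketch of a forward pass that projects first and then splits, written in plain PyTorch. The attribute names (heads, projection_dim, queries_linear, keys_linear, values_linear, out_linear) are modelled on the snippets in the question; this is my own illustration, not the exact code from the tutorial.

import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    # Minimal sketch of the "linear first, then split" ordering.
    # Names mirror the question's snippets; not the tutorial's exact code.
    def __init__(self, d_model, heads):
        super().__init__()
        assert d_model % heads == 0
        self.heads = heads
        self.projection_dim = d_model // heads
        self.queries_linear = nn.Linear(d_model, d_model)
        self.keys_linear = nn.Linear(d_model, d_model)
        self.values_linear = nn.Linear(d_model, d_model)
        self.out_linear = nn.Linear(d_model, d_model)

    def forward(self, queries, keys, values, mask=None):
        batch_size = queries.size(0)

        # project the full d_model vectors first, then reshape into heads
        q = self.queries_linear(queries).view(batch_size, -1, self.heads, self.projection_dim).transpose(1, 2)
        k = self.keys_linear(keys).view(batch_size, -1, self.heads, self.projection_dim).transpose(1, 2)
        v = self.values_linear(values).view(batch_size, -1, self.heads, self.projection_dim).transpose(1, 2)

        # scaled dot-product attention per head
        scores = q @ k.transpose(-2, -1) / self.projection_dim ** 0.5
        if mask is not None:
            # assumes mask broadcasts to (batch, heads, seq_q, seq_k), 0 = masked
            scores = scores.masked_fill(mask == 0, float("-inf"))
        attn = F.softmax(scores, dim=-1)
        out = attn @ v  # (batch, heads, seq, projection_dim)

        # recombine the heads and apply the final output projection
        out = out.transpose(1, 2).reshape(batch_size, -1, self.heads * self.projection_dim)
        return self.out_linear(out)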
Edit:
I found more evidence that the linear projection is done before the split; just look at the GPT architecture here:
https://upload.wikimedia.org/wikipedia/commons/9/91/Full_GPT_architecture.png
You can clearly see that the linear layers come before the split into heads.
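A quick way to see why the ordering matters (my own sanity check, assuming the per-head linear in the first snippet maps projection_dim to projection_dim and is shared across heads): when you project the full d_model vector and then split, every head's output can depend on every input feature; when you split first and then apply the per-head linear, each head only ever sees its own slice of the input.

import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, heads = 8, 2
projection_dim = d_model // heads
x = torch.randn(1, 1, d_model)

# linear first, then split: each head's output depends on the whole input
full_linear = nn.Linear(d_model, d_model, bias=False)
out_full = full_linear(x).view(1, 1, heads, projection_dim)

# split first, then a shared per-head linear: each head only sees its slice
head_linear = nn.Linear(projection_dim, projection_dim, bias=False)
out_split = head_linear(x.view(1, 1, heads, projection_dim))

# perturb one feature that falls into the second head's slice
x2 = x.clone()
x2[..., -1] += 1.0

delta_full = (full_linear(x2).view(1, 1, heads, projection_dim) - out_full)[0, 0].norm(dim=-1)
delta_split = (head_linear(x2.view(1, 1, heads, projection_dim)) - out_split)[0, 0].norm(dim=-1)

print(delta_full)   # both heads change
print(delta_split)  # only the second head changes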