Tags: deep-learning, nlp, pytorch, transformer-model, attention-model

Should the queries, keys and values of the transformer be split before or after being passed through the linear layers?


I have seen two different implementations of Multi-Head Attention.

  1. In the first approach, the queries, keys and values are split into heads before being passed through the linear layers, as shown below:

    def split_heads(self, x, batch_size):
        # (batch, seq_len, d_model) -> (batch, seq_len, heads, projection_dim)
        return x.reshape(batch_size, -1, self.heads, self.projection_dim)

    def forward(self, queries, keys, values, mask):
        batch_size = queries.size(0)

        # split queries, keys and values into heads
        queries = self.split_heads(queries, batch_size)
        keys = self.split_heads(keys, batch_size)
        values = self.split_heads(values, batch_size)

        # pass the split tensors through the linear layers
        queries = self.queries_linear(queries)
        keys = self.keys_linear(keys)
        values = self.values_linear(values)
        # ...more code

  2. The second approach is to split the queries, keys and values into heads after passing them through the linear layers (a completed, self-contained sketch of this pattern follows the snippet):

    def forward(self, queries, keys, values, mask=None):
        batch_size = queries.size(0)

        # perform the linear projections, then split the results into heads
        k = self.keys_linear(keys).view(batch_size, -1, self.heads, self.projection_dim)
        q = self.queries_linear(queries).view(batch_size, -1, self.heads, self.projection_dim)
        v = self.values_linear(values).view(batch_size, -1, self.heads, self.projection_dim)
        # ...more code
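
For reference, here is a self-contained minimal sketch of this second pattern (project first, then split into heads) with the elided attention step filled in for illustration; the class layout and names such as out_linear are my own assumptions, not taken from either snippet above:

    import math
    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class MultiHeadAttention(nn.Module):
        def __init__(self, d_model, heads):
            super().__init__()
            assert d_model % heads == 0
            self.heads = heads
            self.projection_dim = d_model // heads
            self.queries_linear = nn.Linear(d_model, d_model)
            self.keys_linear = nn.Linear(d_model, d_model)
            self.values_linear = nn.Linear(d_model, d_model)
            self.out_linear = nn.Linear(d_model, d_model)

        def forward(self, queries, keys, values, mask=None):
            batch_size = queries.size(0)

            # project first, then split the last dimension into heads
            q = self.queries_linear(queries).view(batch_size, -1, self.heads, self.projection_dim).transpose(1, 2)
            k = self.keys_linear(keys).view(batch_size, -1, self.heads, self.projection_dim).transpose(1, 2)
            v = self.values_linear(values).view(batch_size, -1, self.heads, self.projection_dim).transpose(1, 2)

            # scaled dot-product attention, computed per head
            scores = q @ k.transpose(-2, -1) / math.sqrt(self.projection_dim)
            if mask is not None:
                scores = scores.masked_fill(mask == 0, float("-inf"))
            attn = F.softmax(scores, dim=-1)

            # concatenate the heads and apply the final output projection
            out = (attn @ v).transpose(1, 2).reshape(batch_size, -1, self.heads * self.projection_dim)
            return self.out_linear(out)

    x = torch.randn(2, 10, 512)
    print(MultiHeadAttention(512, 8)(x, x, x).shape)   # torch.Size([2, 10, 512])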
        

According to the Multi-Head Attention figure in the paper Attention Is All You Need, the queries and keys appear to be split before being passed through the linear layers, but in most implementations online they are split after.

Are the two approaches similar or is one better than the other?


Solution

  • Yes, I think the figure in the paper is quite confusing. But if you look at the PyTorch Lightning tutorial, you can see it is linear first, then split. I think this way the linear layer gets to decide how the values are distributed across the heads.
    https://lightning.ai/docs/pytorch/stable/notebooks/course_UvA-DL/05-transformers-and-MH-attention.html
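
    One way to see why "linear first, then split" still matches the per-head projections in the paper: the single d_model-by-d_model linear layer is just all of the per-head projection matrices stacked together, so slicing its output head by head gives the same result as applying a separate small linear per head. A quick numerical check (a minimal sketch of my own, not taken from the tutorial):

        import torch
        import torch.nn as nn

        d_model, heads = 512, 8
        head_dim = d_model // heads

        x = torch.randn(2, 10, d_model)                  # (batch, seq_len, d_model)
        big = nn.Linear(d_model, d_model, bias=False)    # one projection for all heads

        # one big projection, then split the output into heads
        split_after = big(x).view(2, 10, heads, head_dim)

        # equivalent: a small per-head projection taken as slices of the big weight
        per_head = torch.stack(
            [x @ big.weight[h * head_dim:(h + 1) * head_dim].T for h in range(heads)],
            dim=2,
        )                                                # (batch, seq_len, heads, head_dim)

        print(torch.allclose(split_after, per_head, atol=1e-6))   # True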

    Edit:
    I found more evidence that the projection is done before the split; just look at the GPT architecture here: https://upload.wikimedia.org/wikipedia/commons/9/91/Full_GPT_architecture.png You can clearly see that the linear layer comes first.
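
    For what it's worth, PyTorch's own nn.MultiheadAttention appears to follow the same order (the input projection is applied to the full embedding before the result is reshaped into heads), so it can also serve as a reference:

        import torch
        import torch.nn as nn

        mha = nn.MultiheadAttention(embed_dim=512, num_heads=8, batch_first=True)
        x = torch.randn(2, 10, 512)     # (batch, seq_len, embed_dim)
        out, weights = mha(x, x, x)     # self-attention: queries = keys = values
        print(out.shape)                # torch.Size([2, 10, 512])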