pytorch transformer-model

Why do heads share the same K, Q, V weights (matrices) in a transformer?


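    import torch
    import torch.nn as nn


    # Class wrapper and __init__ signature assumed from the names used in the snippet
    class SelfAttention(nn.Module):
        def __init__(self, embed_size, heads):
            super().__init__()
            self.embed_size = embed_size
            self.heads = heads
            self.head_dim = embed_size // heads
            assert self.head_dim * heads == embed_size, "embed_size must be divisible by heads"

            # One head_dim x head_dim layer per projection; it is applied to every
            # head's slice in forward(), so all heads share these weights
            self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
            self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
            self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
            self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

        def forward(self, values, keys, query, mask):
            # Get number of training examples
            N = query.shape[0]

            value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]

            # Split the embedding into self.heads different pieces
            values = values.reshape(N, value_len, self.heads, self.head_dim)
            keys = keys.reshape(N, key_len, self.heads, self.head_dim)
            query = query.reshape(N, query_len, self.heads, self.head_dim)

            values = self.values(values)  # (N, value_len, heads, head_dim)
            keys = self.keys(keys)  # (N, key_len, heads, head_dim)
            queries = self.queries(query)  # (N, query_len, heads, head_dim)

            # Einsum does the query*keys matrix multiplication for each training example
            # and each head, scoring every query position against every key position
            # (equivalent to a batched matmul / bmm)
            energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])
            # energy: (N, heads, query_len, key_len)
            # ... remainder of forward() (masking, fc_out) not shown in the question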

I have noticed that many implementations of multi-head attention look like the code above, but I am confused about why the K, Q, V projections here appear to be shared across the different heads.

Is it because they receive the same signal during backpropagation?
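A minimal sketch to check what the layout in the snippet above actually does, using arbitrary small sizes (`embed_size=8`, `heads=2`, chosen only for illustration): applying one `nn.Linear(head_dim, head_dim)` to the `(N, seq_len, heads, head_dim)` tensor gives the same result as multiplying the full embedding by a block-diagonal matrix whose diagonal blocks are all that one weight. So every head is projected with the same matrix, and that matrix never mixes features from different heads' slices.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Illustrative sizes only
N, seq_len, embed_size, heads = 2, 5, 8, 2
head_dim = embed_size // heads

# Per-head projection as in the question: a single head_dim x head_dim layer
proj = nn.Linear(head_dim, head_dim, bias=False)

x = torch.randn(N, seq_len, embed_size)
# Split into heads, then project: the same weight is applied to every head
out_split = proj(x.reshape(N, seq_len, heads, head_dim)).reshape(N, seq_len, embed_size)

# Equivalent single projection: block-diagonal matrix with identical blocks
full_weight = torch.block_diag(*[proj.weight] * heads)  # (embed_size, embed_size)
out_full = x @ full_weight.T

print(torch.allclose(out_split, out_full, atol=1e-6))  # True

# Weight count: shared per-head layer vs. a full embed_size x embed_size layer
shared_params = proj.weight.numel()  # head_dim * head_dim
full_params = nn.Linear(embed_size, embed_size, bias=False).weight.numel()
print(shared_params, full_params)
```

The shared layer has `head_dim * head_dim` weights, while a full `embed_size × embed_size` projection has `heads ** 2` times as many.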


Solution

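  • In the standard formulation it's not shared. See the original implementation (https://github.com/google-research/google-research/blob/6a30ad7a6655fc481ab040ad6e54a92be93a8db3/summae/transformer.py#L73), and also the implementations in Hugging Face and PyTorch:

    https://github.com/huggingface/transformers/blob/c55d6e4e10ce2d9c37e5f677f0842b04ef8b73f3/src/transformers/models/bert/modeling_bert.py#L251

    https://github.com/pytorch/pytorch/blob/f5bfa4d0888e6cd5984092b38cb8b10609558d05/torch/nn/modules/activation.py#L946

    In the Hugging Face and PyTorch versions, Q, K and V are each computed by projecting the full embedding (embed_size → heads * head_dim) and only then splitting the result into heads, so every head gets its own block of rows of the weight matrix and sees the whole embedding. The snippet in the question splits first and then applies one head_dim × head_dim layer to every head's slice, which does share those weights across heads; that is a property of this particular implementation, not of multi-head attention in general.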
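For comparison, a minimal sketch of the layout used by the Hugging Face and PyTorch implementations (again with illustrative sizes, and showing only the query projection): one `embed_size → heads * head_dim` projection applied to the full embedding, then split into heads. Computing each head separately from its own slice of rows of the weight gives the same queries, which shows that every head effectively has its own `embed_size × head_dim` matrix. PyTorch's `nn.MultiheadAttention` additionally packs the Q, K and V weights into a single `in_proj_weight` of shape `(3 * embed_dim, embed_dim)`; that packing is not reproduced here.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Illustrative sizes only
N, seq_len, embed_size, heads = 2, 5, 8, 2
head_dim = embed_size // heads

# Standard layout: one embed_size -> heads * head_dim projection, split afterwards
query_proj = nn.Linear(embed_size, heads * head_dim, bias=False)

x = torch.randn(N, seq_len, embed_size)
# Project the full embedding, then split into heads: (N, heads, seq_len, head_dim)
q = query_proj(x).reshape(N, seq_len, heads, head_dim).transpose(1, 2)

# Same result computed head by head: head h uses its own block of rows
# W[h*head_dim:(h+1)*head_dim, :], i.e. a separate embed_size -> head_dim matrix
for h in range(heads):
    w_h = query_proj.weight[h * head_dim:(h + 1) * head_dim]  # (head_dim, embed_size)
    q_h = x @ w_h.T  # (N, seq_len, head_dim)
    assert torch.allclose(q[:, h], q_h, atol=1e-6)

# The per-head blocks are independent parameters, not copies of one shared matrix
print(torch.equal(query_proj.weight[:head_dim], query_proj.weight[head_dim:]))  # False
```

Because each head's block is trained independently, different heads can learn different projections, which is the point of having several heads.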