```python
        # ... (class definition and start of __init__ not shown)
        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

    def forward(self, values, keys, query, mask):
        # Get batch size (number of training examples)
        N = query.shape[0]
        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]

        # Split the embedding into self.heads different pieces
        values = values.reshape(N, value_len, self.heads, self.head_dim)
        keys = keys.reshape(N, key_len, self.heads, self.head_dim)
        query = query.reshape(N, query_len, self.heads, self.head_dim)

        values = self.values(values)  # (N, value_len, heads, head_dim)
        keys = self.keys(keys)  # (N, key_len, heads, head_dim)
        queries = self.queries(query)  # (N, query_len, heads, head_dim)

        # Einsum computes the query-key dot products: for each example n and
        # head h, every query position q is scored against every key position k.
        # It is equivalent to a batched matmul Q @ K^T, just written compactly.
        energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])
        # ... (rest of forward not shown)
```
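To make the einsum line concrete, here is a small sketch (hypothetical tensor sizes, PyTorch assumed) showing that `"nqhd,nkhd->nhqk"` gives the same scores as permuting the heads next to the batch dimension and doing a batched `Q @ K^T`:

```python
import torch

torch.manual_seed(0)

# Hypothetical sizes, chosen only for illustration
N, q_len, k_len, heads, head_dim = 2, 3, 5, 4, 8
queries = torch.randn(N, q_len, heads, head_dim)
keys = torch.randn(N, k_len, heads, head_dim)

# Scores via einsum, as in the snippet: sum over d for every (n, h, q, k)
energy_einsum = torch.einsum("nqhd,nkhd->nhqk", queries, keys)

# Same scores via batched matmul: move heads next to batch, then Q @ K^T
q = queries.permute(0, 2, 1, 3)  # (N, heads, q_len, head_dim)
k = keys.permute(0, 2, 3, 1)     # (N, heads, head_dim, k_len)
energy_matmul = q @ k            # (N, heads, q_len, k_len)

print(energy_einsum.shape)  # torch.Size([2, 4, 3, 5])
print(torch.allclose(energy_einsum, energy_matmul, atol=1e-5))  # True
```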
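I have noticed that many implementations of multi-head attention look like the code above. What confuses me is that, here, the K, Q and V projections appear to be shared across heads: each nn.Linear(head_dim, head_dim) is applied after the reshape, so every head's slice passes through the same weights.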
Is that because all heads receive the same signal during backpropagation?
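The sharing can be checked directly. This sketch (hypothetical sizes, PyTorch assumed) applies one nn.Linear(head_dim, head_dim) to a tensor reshaped to (N, seq_len, heads, head_dim), as the code above does, and confirms that every head's slice is multiplied by the same weight matrix. The tying comes from there being a single parameter tensor, not from the gradients: during backpropagation, autograd sums the contributions from all heads into that one weight.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical sizes, chosen only for illustration
N, seq_len, heads, head_dim = 2, 3, 4, 8
embed_size = heads * head_dim

# Projection as in the snippet: it acts on the last dimension (head_dim) only
proj = nn.Linear(head_dim, head_dim, bias=False)

x = torch.randn(N, seq_len, embed_size)
x_heads = x.reshape(N, seq_len, heads, head_dim)
out = proj(x_heads)  # (N, seq_len, heads, head_dim)

# Each head's slice is multiplied by the very same matrix, proj.weight
for h in range(heads):
    expected = x_heads[:, :, h, :] @ proj.weight.T
    assert torch.allclose(out[:, :, h, :], expected, atol=1e-5)

# Only head_dim * head_dim parameters, shared by all heads
print(sum(p.numel() for p in proj.parameters()))  # 64
```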
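In the Transformer as originally designed, the projections are not shared between heads. See the original implementation (https://github.com/google-research/google-research/blob/6a30ad7a6655fc481ab040ad6e54a92be93a8db3/summae/transformer.py#L73), as well as the implementations in Hugging Face Transformers and PyTorch. The usual approach (for example in PyTorch's nn.MultiheadAttention) projects the full embedding with a single embed_size-by-embed_size matrix and only then splits the result into heads, so each head gets its own slice of the weights and sees the whole embedding. The code in the question instead splits first and applies one Linear(head_dim, head_dim) to every slice, which ties the weights across heads.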
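A minimal sketch of the unshared layout, assuming the common pattern of one nn.Linear(embed_size, embed_size) per projection applied before the reshape (sizes are hypothetical). It checks that head h uses its own block of rows of the weight matrix, i.e. a separate (head_dim, embed_size) projection that reads the full embedding:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical sizes, chosen only for illustration
N, seq_len, heads, head_dim = 2, 3, 4, 8
embed_size = heads * head_dim

# Unshared variant: project the full embedding first, then split into heads
proj = nn.Linear(embed_size, embed_size, bias=False)

x = torch.randn(N, seq_len, embed_size)
out = proj(x).reshape(N, seq_len, heads, head_dim)

# Head h is produced by its own rows of proj.weight and sees all of x
for h in range(heads):
    W_h = proj.weight[h * head_dim:(h + 1) * head_dim, :]  # (head_dim, embed_size)
    assert torch.allclose(out[:, :, h, :], x @ W_h.T, atol=1e-5)

# embed_size * embed_size parameters, i.e. heads separate projection matrices
print(sum(p.numel() for p in proj.parameters()))  # 1024
```

Compared with the shared version, this has heads times as many parameters per projection, and no head is restricted to its own slice of the input embedding.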