Within the transformer units of BERT, there are modules called Query, Key, and Value, or simply Q, K, V.
Based on the BERT paper and code (particularly modeling.py), my pseudocode understanding of the forward pass of an attention module (using Q, K, V) with a single attention head is as follows:
q_param = a matrix of learned parameters
k_param = a matrix of learned parameters
v_param = a matrix of learned parameters
d = the attention-head size (a scalar; size_per_head in modeling.py)

def attention(from_tensor, to_tensor, attention_mask):
    q = from_tensor * q_param
    k = to_tensor * k_param
    v = to_tensor * v_param
    attention_scores = q * transpose(k) / sqrt(d)
    attention_scores += some_function(attention_mask)  # attention_mask is usually just ones
    attention_probs = dropout(softmax(attention_scores))
    context = attention_probs * v
    return context
Note that BERT uses "self-attention," so from_tensor and to_tensor are the same in BERT; I think both are simply the output of the previous layer.
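To check my own understanding, here is a rough runnable version of the above in NumPy. It is only a sketch: the toy shapes, the random matrices standing in for learned parameters, the softmax helper, and the omission of dropout, multiple heads, and the output projection are my simplifications, not BERT's actual code.

import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

seq_len, hidden, d = 4, 8, 8                 # toy sizes (my choice, not BERT's)
rng = np.random.default_rng(0)

q_param = rng.normal(size=(hidden, d))       # stand-ins for the learned parameter matrices
k_param = rng.normal(size=(hidden, d))
v_param = rng.normal(size=(hidden, d))

def attention(from_tensor, to_tensor, attention_mask=None):
    q = from_tensor @ q_param                    # (seq_len, d): queries come from from_tensor
    k = to_tensor @ k_param                      # (seq_len, d): keys come from to_tensor
    v = to_tensor @ v_param                      # (seq_len, d): values come from to_tensor
    attention_scores = q @ k.T / np.sqrt(d)      # (seq_len, seq_len)
    if attention_mask is not None:
        # large negative additive term where the mask is 0, as in modeling.py
        attention_scores += (1.0 - attention_mask) * -10000.0
    attention_probs = softmax(attention_scores)  # dropout omitted in this sketch
    return attention_probs @ v                   # (seq_len, d)

layer_input = rng.normal(size=(seq_len, hidden)) # stand-in for the previous layer's output
context = attention(layer_input, layer_input)    # self-attention: the same tensor twice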
Questions
For your first question: BERT is based on the encoder of the transformer model from the 2017 Vaswani et al. paper "Attention Is All You Need." The query/key/value metaphor already appears in that paper (although, as I have learned from the comments above, that paper is not the original source of the idea). However, the metaphor actually works best for the other part of the transformer, namely the decoder; this is because, as you say, the encoder uses self-attention, and it seems to me that the queries and keys play a symmetric role in BERT. So it may be easier to understand the metaphor for the transformer's decoder than for BERT.
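The difference shows up in where the inputs come from: in BERT's self-attention one tensor supplies the queries, keys, and values, while in the decoder's encoder-decoder attention the queries come from the decoder side and the keys and values from the encoder output, so the three roles are genuinely distinct. Here is a small toy illustration of that difference (my own construction; the shapes and random tensors are arbitrary):

import numpy as np

rng = np.random.default_rng(0)
d = 8
w_q, w_k, w_v = (rng.normal(size=(d, d)) for _ in range(3))

def attn(query_source, key_value_source):
    q = query_source @ w_q
    k = key_value_source @ w_k
    v = key_value_source @ w_v
    scores = q @ k.T / np.sqrt(d)
    probs = np.exp(scores - scores.max(axis=-1, keepdims=True))
    probs /= probs.sum(axis=-1, keepdims=True)
    return probs @ v

encoder_states = rng.normal(size=(6, d))   # hypothetical source sentence, 6 positions
decoder_states = rng.normal(size=(4, d))   # hypothetical target prefix, 4 positions

# BERT / encoder self-attention: one tensor plays all three roles.
self_context = attn(encoder_states, encoder_states)    # shape (6, d)

# Decoder's encoder-decoder attention: decoder states supply the queries,
# the encoder output supplies the keys and values.
cross_context = attn(decoder_states, encoder_states)   # shape (4, d)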
To my understanding, in the Vaswani et al. transformer model, the queries and keys allow all positions of decoder layer j-1 to attend to all positions of encoder layer j via the attention scores. The values are then selected by the queries and keys: the result of the attention layer is the sum of the values weighted by the attention scores. The projections of the queries and keys determine where the attention for each position is placed. For example, an extreme case could be that the queries are projected by the identity function and the keys are projected by a permutation that moves position i to position i+1. The dot product of the queries and keys would then let each position of decoder layer j-1 attend to the position before it in encoder layer j. So decoder layer j-1 is referred to as the queries when, together with the keys, it decides how much each position of encoder layer j (the same tensor again, but now playing the role of the values) will contribute.
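To make the permutation example concrete, here is a toy numerical check of my own (not from the paper). I use a single set of scaled one-hot states for both sides, so different positions are orthogonal, project the queries by the identity, and permute the keys by one position; the orientation of the permutation is chosen so that, as described above, each position ends up attending to the position before it.

import numpy as np

n = 5
X = 4.0 * np.eye(n)                 # toy states: scaled one-hot rows, so distinct positions are orthogonal

Q = X                               # queries projected by the identity
K = np.roll(X, -1, axis=0)          # keys permuted by one position
V = X                               # identity-projected values, just to see which position is picked up

scores = Q @ K.T / np.sqrt(n)
probs = np.exp(scores - scores.max(axis=-1, keepdims=True))
probs /= probs.sum(axis=-1, keepdims=True)
context = probs @ V

print(np.argmax(probs, axis=-1))    # [4 0 1 2 3]: position t attends to position t-1
print(np.round(context, 2))         # row t is essentially the value from position t-1

The attention scores here are driven entirely by the key permutation, so the value at the attended position is passed through almost unchanged; that is the sense in which the queries and keys decide how much each value contributes.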