
Key matrix redundant in Transformer language models?


Simple implementations of Transformer language models, such as this one, define 3 matrices K, Q, V to compute keys, queries and values. However, matrices K and Q are never used separately: all Transformer computations only use their product Q^t K. So I wonder why we don't learn this product matrix directly instead of splitting it into 2 matrices K and Q.
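A quick way to check the claim that only Q^t K matters for the attention scores is to compute them both ways. The sketch below assumes NumPy, toy sizes d = 8 and n = 4, and random matrices standing in for learned weights; it omits the 1/sqrt(n) scaling and the softmax, since both are applied after the scores and do not affect the equivalence.

```python
import numpy as np

rng = np.random.default_rng(0)
d, n = 8, 4  # embedding dim d, key/query dim n (toy sizes, chosen arbitrarily)

# Q and K map a d-dimensional embedding to an n-dimensional query/key, so each is n x d.
Q = rng.standard_normal((n, d))
K = rng.standard_normal((n, d))

X = rng.standard_normal((5, d))  # 5 token embeddings, one per row

# Route 1: project to queries and keys separately, then take dot products.
queries = X @ Q.T                # shape (5, n)
keys = X @ K.T                   # shape (5, n)
scores_split = queries @ keys.T  # scores[i, j] = q_i . k_j

# Route 2: form the d x d product matrix Q^t K once and use it as a bilinear form.
M = Q.T @ K                      # shape (d, d)
scores_product = X @ M @ X.T     # scores[i, j] = x_i^t (Q^t K) x_j

print(np.allclose(scores_split, scores_product))  # True
```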

Part of the answer may come from the size of K and Q, which is d -> n, where d is the dimension of the token embeddings and n is the dimension of keys and queries. The size of Q^t K is d -> d. So learning K and Q separately means optimizing 2*n*d parameters, whereas learning the product Q^t K directly means optimizing d*d parameters. The only useful splitting I see is when n < d/2, because only then are there fewer parameters to optimize. At the limit case n = d/2, both options cost d^2 parameters, yet the product matrix Q^t K has rank at most d/2, which is very degenerate. With the same number of parameters d^2, we could learn an unconstrained square matrix, which might capture more flexible and subtle patterns in the training data.
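The parameter trade-off and the rank bound can be checked numerically. This is a minimal sketch assuming NumPy and an arbitrary toy dimension d = 16; `param_counts` is a hypothetical helper introduced only for illustration. The rank check uses random Gaussian matrices, for which the product Q^t K generically reaches the bound of exactly n.

```python
import numpy as np

def param_counts(d, n):
    """Parameters for separate Q, K (each n x d) versus one full d x d matrix."""
    return 2 * n * d, d * d

d = 16  # toy embedding dimension
for n in (4, 8, 12):  # below, at, and above the d/2 break-even point
    split, full = param_counts(d, n)
    print(f"n={n}: split={split}, full={full}, split saves parameters: {split < full}")

# At the break-even point n = d/2, the product Q^t K still has rank at most n.
rng = np.random.default_rng(0)
n = d // 2
Q = rng.standard_normal((n, d))
K = rng.standard_normal((n, d))
M = Q.T @ K  # d x d, but built from rank-n factors
print(np.linalg.matrix_rank(M), "out of", d)
```

At n = d/2 the two options cost the same, but only the full matrix can have rank d.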

In the base model of the "Attention Is All You Need" paper (page 9), we see d = 512 and n = 64, so the product matrix Q^t K has rank at most 64 out of 512, which is very degenerate. Is reducing the number of parameters the true and only intent here? Is there a theoretical justification that such low ranks help natural language processing?
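Plugging the base-model values into the same counts shows how large the saving is. In that model n = 64 is the per-head key/query size (d_model = 512 split across 8 heads), so the numbers below are per attention head; this is plain arithmetic, not a reproduction of the paper's code.

```python
d, n = 512, 64  # base-model values: d_model and the per-head key/query size

split = 2 * n * d  # separate query and key projections for one attention head
full = d * d       # one unconstrained d x d bilinear matrix for that head

print(split, full, full // split)  # 65536 262144 4
print(f"rank of Q^t K is at most {n} out of {d}")
```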


Solution

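  • The 2 * n * d vs. d * d structure is very similar to how LoRA (low-rank adaptation) works, or indeed a WALS (weighted alternating least squares) model: both represent a large matrix as the product of two thin factor matrices. So as long as 2 * n * d < d * d, this might well be a good way of saving parameters.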
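To make the low-rank-factorization analogy concrete, the sketch below (assuming NumPy; the function name `factorize` is hypothetical) splits a d x d matrix into Q and K of shape n x d via a truncated SVD. If the target matrix already has rank at most n, the split is exact and uses 2*n*d numbers instead of d*d; for a general matrix, the truncated SVD gives the best rank-n approximation (Eckart-Young). This only illustrates the factorized form: real Transformers learn Q and K directly by gradient descent, not through an SVD.

```python
import numpy as np

def factorize(M, n):
    """Split a d x d matrix into Q, K (each n x d) so that Q.T @ K is the
    best rank-n approximation of M (truncated SVD)."""
    U, s, Vt = np.linalg.svd(M)
    root = np.sqrt(s[:n])          # share each singular value between the factors
    Q = (U[:, :n] * root).T        # n x d, equals diag(root) @ U_n^t
    K = root[:, None] * Vt[:n]     # n x d, equals diag(root) @ V_n^t
    return Q, K

rng = np.random.default_rng(0)
d, n = 16, 4
M_low = rng.standard_normal((d, n)) @ rng.standard_normal((n, d))  # rank n by construction

Q, K = factorize(M_low, n)
print(np.allclose(Q.T @ K, M_low))  # True: nothing is lost for a rank-n target
print(Q.size + K.size, M_low.size)  # 128 256
```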