While testing with PyTorch's MultiheadAttention, I noticed that increasing or decreasing the number of heads in the multi-head attention layer does not change the total number of learnable parameters of my model.
Is this behavior correct? And if so, why?
Shouldn't the number of heads affect the number of parameters the model can learn?
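For concreteness, here is a minimal sketch of the kind of check that produces this observation (the embed_dim and num_heads values are just illustrative, not the ones from my actual model):

```python
import torch.nn as nn

def count_params(module: nn.Module) -> int:
    # Total number of learnable parameters in the module.
    return sum(p.numel() for p in module.parameters())

embed_dim = 512  # illustrative model dimensionality
for num_heads in (1, 2, 4, 8):  # each must divide embed_dim
    mha = nn.MultiheadAttention(embed_dim, num_heads)
    print(num_heads, count_params(mha))
# The printed total is identical for every head count.
```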
The standard implementation of multi-headed attention divides the model's dimensionality by the number of attention heads.
A model of dimensionality $d$ with a single attention head would project embeddings to a single triplet of $d$-dimensional query, key and value tensors (each projection counting $d^2$ parameters, excluding biases, for a total of $3d^2$).
A model of the same dimensionality with $k$ attention heads would project embeddings to $k$ triplets of $d/k$-dimensional query, key and value tensors (each projection counting $d \times d/k = d^2/k$ parameters, excluding biases, for a total of $3k \cdot d^2/k = 3d^2$).
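You can see this directly in PyTorch's module (a sketch assuming the default configuration, where the key and value dimensions equal embed_dim): the query, key and value projections are stored as one concatenated in_proj_weight of shape $(3d, d)$ and the output projection as a $(d, d)$ matrix; the split into heads is only a runtime reshape, so neither shape depends on the number of heads.

```python
import torch.nn as nn

d = 256  # illustrative model dimensionality
for k in (1, 4, 8):  # head counts that divide d
    mha = nn.MultiheadAttention(embed_dim=d, num_heads=k)
    # Q, K and V projections are concatenated into a single (3d, d) matrix;
    # the output projection is (d, d). Neither depends on k.
    print(k, tuple(mha.in_proj_weight.shape), tuple(mha.out_proj.weight.shape))
# Every iteration prints (768, 256) and (256, 256).
```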