Tags: python, python-3.x, nlp, pytorch, attention-model

Number of learnable parameters of MultiheadAttention


While experimenting with PyTorch's nn.MultiheadAttention, I noticed that increasing or decreasing the number of heads does not change the total number of learnable parameters of my model.

Is this behavior correct? And if so, why?

Shouldn't the number of heads affect the number of parameters the model can learn?
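
For reference, a minimal snippet reproducing what I'm seeing (the embed_dim of 512 is just an arbitrary example):

```python
import torch.nn as nn

def count_params(module):
    return sum(p.numel() for p in module.parameters())

# Same embed_dim, different numbers of heads: the totals come out identical.
for num_heads in (1, 2, 4, 8):
    mha = nn.MultiheadAttention(embed_dim=512, num_heads=num_heads)
    print(num_heads, count_params(mha))
```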


Solution

  • The standard implementation of multi-head attention divides the model's dimensionality by the number of attention heads.

    A model of dimensionality d with a single attention head would project embeddings to a single triplet of d-dimensional query, key and value tensors (each projection counting d² parameters, excluding biases, for a total of 3d²).

    A model of the same dimensionality with k attention heads would project embeddings to k triplets of d/k-dimensional query, key and value tensors (each projection counting d × d/k = d²/k parameters, excluding biases, for a total of 3k · d²/k = 3d²); see the sketch after the references below.


    References:

    From the original paper ("Attention Is All You Need", Vaswani et al., 2017): [excerpt noting that each head uses d_k = d_v = d_model / h]

    The PyTorch implementation you cited (torch.nn.MultiheadAttention): [source excerpt where head_dim = embed_dim // num_heads]
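
To make the counting above concrete, here is a small sketch. The values d = 512 and k = 8 are arbitrary, and the Linear layers only stand in for the Q/K/V projection matrices, not for a full attention module:

```python
import torch.nn as nn

d, k = 512, 8            # model dimensionality and number of heads (example values)
head_dim = d // k        # each head works in a d/k-dimensional subspace

# k heads, each with its own d -> d/k projection for Q, K and V:
# k * 3 * (d * d/k) = 3 * d**2 weights in total.
multi_head_proj = nn.ModuleList(
    [nn.Linear(d, head_dim, bias=False) for _ in range(3 * k)]
)

# A single head with d -> d projections for Q, K and V: 3 * d**2 weights.
single_head_proj = nn.ModuleList(
    [nn.Linear(d, d, bias=False) for _ in range(3)]
)

count = lambda m: sum(p.numel() for p in m.parameters())
print(count(multi_head_proj), count(single_head_proj))  # both print 786432 (= 3 * 512**2)
```

Both configurations end up with the same number of projection weights, which is why changing num_heads leaves the model's parameter count unchanged.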