
Understanding dimensions in MultiHeadAttention layer of Tensorflow


I'm learning multi-head attention from this article. As the author describes, the structure of MHA (from the original paper) is as follows:

[Figure: multi-head attention structure from the original paper]

But the MultiHeadAttention layer in TensorFlow seems to be more flexible:

  1. It does not require key_dim * num_heads = embed_dim. For example:
import tensorflow as tf

layer = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=4)
x = tf.keras.Input(shape=[3, 5])
layer(x, x)
# no error

Is the depth of the weight matrices in the tf MHA layer set to key_dim * num_heads regardless of embed_dim, so that Q/K/V can still be properly split across num_heads?

  2. However, the output depth of the tf MHA layer is (by default) guaranteed to be embed_dim. So is there a final dense layer with embed_dim units that restores the dimension? (A quick shape check follows below.)
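
A minimal sketch of that shape check, reusing the same toy sizes as above (num_heads=2, key_dim=4, embed_dim=5):

import tensorflow as tf

layer = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=4)
x = tf.keras.Input(shape=[3, 5])
print(layer(x, x).shape)  # (None, 3, 5): last dim is embed_dim=5, not num_heads * key_dim = 8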

Solution

  • Yes, to both 1 and 2. You can probe the weights as follows:

    layer = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=4, use_bias=False)  # use_bias=False for simplicity
    x = tf.keras.Input(shape=[3, 5])
    layer(x, x)
    

    Then inspect the associated weights:

    weight_names = ['query', 'keys', 'values', 'proj']
    for name, out in zip(weight_names, layer.get_weights()):
        print(name, out.shape)
    

    Output shapes:

    query (5, 2, 4) # (embed_dim, num_heads, key_dim)
    keys (5, 2, 4)  # (embed_dim, num_heads, key_dim)
    values (5, 2, 4) # (embed_dim, num_heads, value_dim/key_dim)
    proj (2, 4, 5)  # (num_heads, value_dim, embed_dim), the output projection back to embed_dim
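
    To confirm how these weights are used, the attention output can be reproduced by hand with tf.einsum. This is a minimal sketch, assuming the standard scaled dot-product formulation (per-head Q/K/V projections, softmax over the key axis, then a projection back to embed_dim); it should match layer(q, q) up to float precision:

    import numpy as np
    import tensorflow as tf

    layer = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=4, use_bias=False)
    q = tf.random.normal([1, 3, 5])        # (batch, seq_len, embed_dim)
    ref = layer(q, q)                      # builds the layer and gives the reference output

    wq, wk, wv, wo = layer.get_weights()   # shapes as printed above

    Q = tf.einsum('btd,dhk->bthk', q, wq)  # per-head queries: (batch, t, num_heads, key_dim)
    K = tf.einsum('bsd,dhk->bshk', q, wk)  # per-head keys
    V = tf.einsum('bsd,dhv->bshv', q, wv)  # per-head values
    scores = tf.einsum('bthk,bshk->bhts', Q, K) / np.sqrt(4.0)  # scale by sqrt(key_dim)
    attn = tf.nn.softmax(scores, axis=-1)                       # softmax over key positions
    heads = tf.einsum('bhts,bshv->bthv', attn, V)               # per-head weighted values
    out = tf.einsum('bthv,hvd->btd', heads, wo)                 # final projection to embed_dim=5

    print(np.allclose(out.numpy(), ref.numpy(), atol=1e-5))     # should print True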