tensorflow, keras, deep-learning, nlp, lstm

Verifying the implementation of Multi-Head Attention in a Transformer


I have implemented the multi-head attention of the Transformer (the MultiAttention layer below). There are so many implementations around that it's confusing. Can someone please verify whether my implementation is correct?

The scaled dot-product attention is adapted from: https://www.tensorflow.org/tutorials/text/transformer#setup

import tensorflow as tf

def scaled_dot_product(q,k,v):
    #calculates Q . K(transpose)
    qkt = tf.matmul(q,k,transpose_b=True)
    #calculates scaling factor
    dk = tf.math.sqrt(tf.cast(q.shape[-1],dtype=tf.float32))
    scaled_qkt = qkt/dk
    softmax = tf.nn.softmax(scaled_qkt,axis=-1)
    
    z = tf.matmul(softmax,v)
    #shape: (m,Tx,depth), same shape as q,k,v
    return z

class MultiAttention(tf.keras.layers.Layer):
    def __init__(self,d_model,num_of_heads):
        super(MultiAttention,self).__init__()
        self.d_model = d_model
        self.num_of_heads = num_of_heads
        self.depth = d_model//num_of_heads
        self.wq = [tf.keras.layers.Dense(self.depth) for i in range(num_of_heads)]
        self.wk = [tf.keras.layers.Dense(self.depth) for i in range(num_of_heads)]
        self.wv = [tf.keras.layers.Dense(self.depth) for i in range(num_of_heads)]
        self.wo = tf.keras.layers.Dense(d_model)
        
    def call(self,x):
        
        multi_attn = []
        for i in range(self.num_of_heads):
            Q = self.wq[i](x)
            K = self.wk[i](x)
            V = self.wv[i](x)
            multi_attn.append(scaled_dot_product(Q,K,V))
            
        multi_head = tf.concat(multi_attn,axis=-1)
        multi_head_attention = self.wo(multi_head)
        return multi_head_attention

#Calling the attention 
multi = MultiAttention(d_model=512,num_of_heads=8)
m = 5; sequence_length = 4; word_embedding_dim = 512
sample_ip = tf.constant(tf.random.normal(shape=(m,sequence_length,word_embedding_dim)))
attn = multi(sample_ip)
#shape of op (attn): (5,4,512)

Solution

  • In your implementation, scaled_dot_product scales by the query's last dimension, but according to the original paper the key dimension (d_k) is used for scaling; in self-attention with equal head sizes the two coincide, but the key dimension is the correct choice in general. Apart from that, this implementation seems OK, but it is not general.
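
    For reference, the scaled dot-product attention formula from "Attention Is All You Need" uses the key dimension d_k:

        Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V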

    class MultiAttention(tf.keras.layers.Layer):
        def __init__(self, num_of_heads, out_dim):
            super(MultiAttention,self).__init__()
            self.out_dim      = out_dim
            self.num_of_heads = num_of_heads
            self.depth        = self.out_dim // self.num_of_heads
            self.wq = [tf.keras.layers.Dense(self.depth) for i in range(num_of_heads)]
            self.wk = [tf.keras.layers.Dense(self.depth) for i in range(num_of_heads)]
            self.wv = [tf.keras.layers.Dense(self.depth) for i in range(num_of_heads)]
            self.wo = tf.keras.layers.Dense(self.out_dim)
            
        def call(self,x):
            multi_attn = []
            for i in range(self.num_of_heads):
                Q = self.wq[i](x)
                K = self.wk[i](x)
                V = self.wv[i](x)
                multi_attn.append(self.scaled_dot_product(Q,K,V))
    
            multi_head = tf.concat(multi_attn, axis=-1)
            multi_head_attention = self.wo(multi_head)
            return multi_head_attention
    
        def scaled_dot_product(self, q,k,v):
            qkt = tf.matmul(q, k, transpose_b=True)
            dk = tf.math.sqrt( tf.cast(k.shape[-1], dtype=tf.float32) )
            scaled_qkt = qkt/dk
            softmax = tf.nn.softmax(scaled_qkt, axis=-1)
            z = tf.matmul(softmax, v)
            return z
    
    multi = MultiAttention(num_of_heads=3, out_dim=32)
    sample_ip = tf.random.normal(shape=(2, 2, 32)); print(sample_ip.shape)
    multi(sample_ip).shape
    

    The general transformer attention can be illustrated as follows: the first two linear layers produce the query and key, which are responsible for producing the attention weight maps, and the value is then weighted by these maps via matrix multiplication.

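
    To make the shapes concrete, here is a minimal sketch (illustrative only, not part of the tutorial code): the query/key product produces an attention map of shape (seq_q, seq_k), which then weights the value matrix.

    import tensorflow as tf

    q = tf.random.normal((1, 4, 8))   # (batch, seq_q, depth)
    k = tf.random.normal((1, 6, 8))   # (batch, seq_k, depth)
    v = tf.random.normal((1, 6, 8))   # (batch, seq_k, depth)
    # one attention weight per (query, key) pair
    weights = tf.nn.softmax(tf.matmul(q, k, transpose_b=True) / tf.math.sqrt(8.0), axis=-1)
    print(weights.shape)                 # (1, 4, 6)
    # each query position receives a weighted sum of the values
    print(tf.matmul(weights, v).shape)   # (1, 4, 8)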

    I understand you're trying to minimize the original TF tutorial code, but I think you should add that reference to your original question first. Also note that the original implementation returns the attention weights (scores) along with the weighted features; I don't think you should skip that.
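
    If you want the simplified MultiAttention layer above to expose the weights as well, one possible sketch (my addition, assuming scaled_dot_product is changed to also return the softmax map) is:

    # methods of the MultiAttention class above (sketch, not the tutorial code)
    def scaled_dot_product(self, q, k, v):
        qkt = tf.matmul(q, k, transpose_b=True)
        dk = tf.math.sqrt(tf.cast(k.shape[-1], dtype=tf.float32))
        softmax = tf.nn.softmax(qkt / dk, axis=-1)
        # return the attention weights alongside the weighted values
        return tf.matmul(softmax, v), softmax

    def call(self, x):
        multi_attn, attn_weights = [], []
        for i in range(self.num_of_heads):
            z, w = self.scaled_dot_product(self.wq[i](x), self.wk[i](x), self.wv[i](x))
            multi_attn.append(z)
            attn_weights.append(w)
        multi_head_attention = self.wo(tf.concat(multi_attn, axis=-1))
        # per-head weights stacked to (batch, num_heads, seq_q, seq_k)
        return multi_head_attention, tf.stack(attn_weights, axis=1)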


    The original code that you're following is more general and more efficiently implemented.

    class MultiHeadAttention(tf.keras.layers.Layer):
        def __init__(self, d_model, num_heads):
            super(MultiHeadAttention, self).__init__()
            self.num_heads = num_heads
            self.d_model = d_model
            assert d_model % self.num_heads == 0
            self.depth = d_model // self.num_heads
            self.wq = tf.keras.layers.Dense(d_model)
            self.wk = tf.keras.layers.Dense(d_model)
            self.wv = tf.keras.layers.Dense(d_model)
            self.dense = tf.keras.layers.Dense(d_model)
    
        def scaled_dot_product_attention(self, q, k, v, mask=None):
            matmul_qk = tf.matmul(q, k, transpose_b=True)  # (..., seq_len_q, seq_len_k)
            # scale matmul_qk
            dk = tf.cast(tf.shape(k)[-1], tf.float32)
            scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)
            # add the mask to the scaled tensor.
            if mask is not None: scaled_attention_logits += (mask * -1e9)
            # softmax is normalized on the last axis (seq_len_k) so that the scores
            # add up to 1.
            attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)  # (..., seq_len_q, seq_len_k)
            output = tf.matmul(attention_weights, v)  # (..., seq_len_q, depth_v)
            return output, attention_weights
    
        def split_heads(self, x, batch_size):
            """Split the last dimension into (num_heads, depth).
            Transpose the result such that the shape is (batch_size, num_heads, seq_len, depth)
            """
            x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
            return tf.transpose(x, perm=[0, 2, 1, 3])
    
        def call(self, v, k, q, mask=None):
            batch_size = tf.shape(q)[0]
            q = self.wq(q)  # (batch_size, seq_len, d_model)
            k = self.wk(k)  # (batch_size, seq_len, d_model)
            v = self.wv(v)  # (batch_size, seq_len, d_model)
    
            q = self.split_heads(q, batch_size)  # (batch_size, num_heads, seq_len_q, depth)
            k = self.split_heads(k, batch_size)  # (batch_size, num_heads, seq_len_k, depth)
            v = self.split_heads(v, batch_size)  # (batch_size, num_heads, seq_len_v, depth)
            # scaled_attention.shape == (batch_size, num_heads, seq_len_q, depth)
            # attention_weights.shape == (batch_size, num_heads, seq_len_q, seq_len_k)
            scaled_attention, attention_weights = self.scaled_dot_product_attention(q, k, v, mask)
    
            scaled_attention = tf.transpose(scaled_attention, perm=[0, 2, 1, 3])  # (batch_size, seq_len_q, num_heads, depth)
            concat_attention = tf.reshape(scaled_attention,  (batch_size, -1, self.d_model))  # (batch_size, seq_len_q, d_model)
            output = self.dense(concat_attention)  # (batch_size, seq_len_q, d_model)
            return output, attention_weights
    

    FYI, as of TF 2.4, the tf.keras.layers.MultiHeadAttention layer is officially available.

    layer = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=2)
    input_tensor = tf.keras.Input(shape=[2, 2, 32]); print(input_tensor.shape)
    print(layer(input_tensor, input_tensor).shape)
    

    You can test these two as follows:

    # custom layer MHA
    multi = MultiHeadAttention(d_model=512, num_heads=2)
    y = tf.random.uniform((1, 60, 512))  
    out, attn = multi(y, k=y, q=y, mask=None)
    out.shape, attn.shape
    (TensorShape([1, 60, 512]), TensorShape([1, 2, 60, 60]))
    
    # built-in layer 
    layer = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=2)
    y = tf.random.uniform((1, 60, 512))  
    out, attn = layer(y, y, return_attention_scores=True)
    out.shape, attn.shape
    (TensorShape([1, 60, 512]), TensorShape([1, 2, 60, 60]))
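
    If you also need masking with the built-in layer, its call accepts an attention_mask argument, a boolean mask broadcastable to (batch, seq_q, seq_k); a quick sketch reusing layer and y from above:

    mask = tf.ones((1, 60, 60), dtype=tf.bool)   # True = position may be attended to
    out = layer(y, y, attention_mask=mask)
    print(out.shape)  # (1, 60, 512)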