I have implemented the multi-head attention layer (`MultiAttention`) used in Transformers. There are so many implementations around that it's confusing. Can someone please verify whether my implementation below is correct?
The scaled dot-product attention is adapted from: https://www.tensorflow.org/tutorials/text/transformer#setup
import tensorflow as tf
def scaled_dot_product(q, k, v, mask=None):
    """Compute scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V.

    Args:
        q: Query tensor of shape (..., seq_len_q, d_k).
        k: Key tensor of shape (..., seq_len_k, d_k). Its last dimension must
            equal q's, as required by ``tf.matmul(q, k, transpose_b=True)``.
        v: Value tensor of shape (..., seq_len_k, depth_v).
        mask: Optional tensor broadcastable to (..., seq_len_q, seq_len_k).
            Positions where it is 1 receive a large negative logit and so get
            (almost) zero attention weight. ``None`` (default) means no mask.

    Returns:
        Tensor of shape (..., seq_len_q, depth_v).
    """
    # Q . K^T -> (..., seq_len_q, seq_len_k)
    qkt = tf.matmul(q, k, transpose_b=True)
    # Scale by sqrt(d_k) as in the paper. tf.shape reads the size at run time,
    # so this also works when the static last dimension is unknown (None);
    # casting to qkt's dtype keeps float16/float64 inputs working.
    dk = tf.cast(tf.shape(k)[-1], qkt.dtype)
    scaled_qkt = qkt / tf.math.sqrt(dk)
    if mask is not None:
        scaled_qkt += tf.cast(mask, scaled_qkt.dtype) * -1e9
    # Normalise over the key axis so each query's weights sum to 1.
    weights = tf.nn.softmax(scaled_qkt, axis=-1)
    return tf.matmul(weights, v)
class MultiAttention(tf.keras.layers.Layer):
    """Multi-head self-attention built from one Dense projection per head.

    Each of the ``num_of_heads`` heads projects the input to
    ``d_model // num_of_heads`` features for Q, K and V and applies
    ``scaled_dot_product``. The per-head outputs are concatenated and then
    projected to ``d_model`` features by ``wo``.
    """

    def __init__(self, d_model, num_of_heads, **kwargs):
        # Keras requires the base Layer to be initialised before any
        # attribute (in particular any sub-layer) is assigned.
        super().__init__(**kwargs)
        if d_model % num_of_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by "
                f"num_of_heads ({num_of_heads})"
            )
        self.d_model = d_model
        self.num_of_heads = num_of_heads
        self.depth = d_model // num_of_heads
        self.wq = [tf.keras.layers.Dense(self.depth) for _ in range(num_of_heads)]
        self.wk = [tf.keras.layers.Dense(self.depth) for _ in range(num_of_heads)]
        self.wv = [tf.keras.layers.Dense(self.depth) for _ in range(num_of_heads)]
        self.wo = tf.keras.layers.Dense(d_model)

    def call(self, x):
        """Apply multi-head self-attention to ``x``.

        Args:
            x: Input tensor; the demo in this file passes shape
                (batch, seq_len, features), e.g. (5, 4, 512).

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of ``d_model`` (the output size of ``wo``).
        """
        # One attention output per head, each (..., seq_len, depth).
        heads = [
            scaled_dot_product(wq(x), wk(x), wv(x))
            for wq, wk, wv in zip(self.wq, self.wk, self.wv)
        ]
        # Concatenate heads -> (..., seq_len, num_of_heads * depth == d_model).
        multi_head = tf.concat(heads, axis=-1)
        return self.wo(multi_head)
# Calling the attention
multi = MultiAttention(d_model=512, num_of_heads=8)
m = 5
sequence_length = 4
word_embedding_dim = 512
# tf.random.normal already returns a tensor; no tf.constant wrapper is needed.
sample_ip = tf.random.normal(shape=(m, sequence_length, word_embedding_dim))
attn = multi(sample_ip)
# shape of op (attn): (5, 4, 512)
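In your implementation of `scaled_dot_product`, you scaled by the query dimension, but the original paper normalizes by the key dimension, √d_k. (Numerically the two are identical here, because `tf.matmul(q, k, transpose_b=True)` requires `q` and `k` to share the same last dimension; using `k` simply matches the paper's notation.) Apart from that, this implementation seems OK, but it is not general.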
class MultiAttention(tf.keras.layers.Layer):
    """Multi-head self-attention with per-head Dense projections.

    Each of the ``num_of_heads`` heads projects the input to
    ``out_dim // num_of_heads`` features for Q, K and V and runs
    ``self.scaled_dot_product``. The head outputs are concatenated and
    projected to ``out_dim`` features by ``wo``.
    """

    def __init__(self, num_of_heads, out_dim, **kwargs):
        # Keras requires the base Layer to be initialised before any
        # attribute (in particular any sub-layer) is assigned.
        super().__init__(**kwargs)
        if out_dim % num_of_heads != 0:
            raise ValueError(
                f"out_dim ({out_dim}) must be divisible by "
                f"num_of_heads ({num_of_heads})"
            )
        self.out_dim = out_dim
        self.num_of_heads = num_of_heads
        self.depth = self.out_dim // self.num_of_heads
        self.wq = [tf.keras.layers.Dense(self.depth) for _ in range(num_of_heads)]
        self.wk = [tf.keras.layers.Dense(self.depth) for _ in range(num_of_heads)]
        self.wv = [tf.keras.layers.Dense(self.depth) for _ in range(num_of_heads)]
        self.wo = tf.keras.layers.Dense(self.out_dim)

    def call(self, x):
        """Apply multi-head self-attention to ``x``.

        Returns a tensor with the same leading dimensions as ``x`` and a last
        dimension of ``out_dim`` (the output size of ``wo``).
        """
        # One attention output per head, each (..., seq_len, depth).
        heads = [
            self.scaled_dot_product(wq(x), wk(x), wv(x))
            for wq, wk, wv in zip(self.wq, self.wk, self.wv)
        ]
        # Concatenate heads -> (..., seq_len, num_of_heads * depth == out_dim).
        multi_head = tf.concat(heads, axis=-1)
        return self.wo(multi_head)

    def scaled_dot_product(self, q, k, v, mask=None):
        """Compute softmax(Q K^T / sqrt(d_k)) V.

        Args:
            q: Queries, (..., seq_len_q, d_k).
            k: Keys, (..., seq_len_k, d_k); last dim must match q's.
            v: Values, (..., seq_len_k, depth_v).
            mask: Optional tensor broadcastable to (..., seq_len_q, seq_len_k);
                positions where it is 1 are pushed to ~0 weight. Default None.

        Returns:
            Tensor of shape (..., seq_len_q, depth_v).
        """
        qkt = tf.matmul(q, k, transpose_b=True)
        # Dynamic shape + logits dtype: robust to unknown static dims and to
        # non-float32 inputs.
        dk = tf.cast(tf.shape(k)[-1], qkt.dtype)
        scaled_qkt = qkt / tf.math.sqrt(dk)
        if mask is not None:
            scaled_qkt += tf.cast(mask, scaled_qkt.dtype) * -1e9
        softmax = tf.nn.softmax(scaled_qkt, axis=-1)
        return tf.matmul(softmax, v)
# Small demo configuration: 3 heads, 32 output features.
multi = MultiAttention(num_of_heads=3, out_dim=32)
input_shape = (2, 2, 32)
sample_ip = tf.random.normal(shape=input_shape)
print(sample_ip.shape)
The general Transformer attention mechanism can be illustrated as follows: the first two linear layers project the query and the key, which together produce the attention-weight maps; these weights are then applied to the value by matrix multiplication.
I understand you're trying to simplify the original TF tutorial code, but I think you should first add a reference to it in your question. In the original implementation, they also return the attention weights (scores) along with the weighted feature maps. I think you shouldn't skip that.
The original code that you're following is more general and more efficient:
class MultiHeadAttention(tf.keras.layers.Layer):
    """Multi-head attention as in the TensorFlow Transformer tutorial.

    Q, K and V are each projected to ``d_model`` features by a single Dense
    layer, split into ``num_heads`` heads of ``depth = d_model // num_heads``
    features, attended independently, then merged back and projected by
    ``dense``.
    """

    def __init__(self, d_model, num_heads, **kwargs):
        super().__init__(**kwargs)
        # Explicit check instead of assert: asserts vanish under ``python -O``.
        if d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.d_model = d_model
        self.depth = d_model // num_heads
        self.wq = tf.keras.layers.Dense(d_model)
        self.wk = tf.keras.layers.Dense(d_model)
        self.wv = tf.keras.layers.Dense(d_model)
        self.dense = tf.keras.layers.Dense(d_model)

    def scaled_dot_product_attention(self, q, k, v, mask=None):
        """Compute softmax(Q K^T / sqrt(d_k) + mask * -1e9) V.

        Args:
            q: Queries, (..., seq_len_q, depth).
            k: Keys, (..., seq_len_k, depth).
            v: Values, (..., seq_len_v, depth_v) with seq_len_v == seq_len_k.
            mask: Optional tensor broadcastable to (..., seq_len_q, seq_len_k);
                positions where it is 1 are pushed to ~0 weight. Default None.

        Returns:
            Tuple ``(output, attention_weights)`` with shapes
            (..., seq_len_q, depth_v) and (..., seq_len_q, seq_len_k).
        """
        matmul_qk = tf.matmul(q, k, transpose_b=True)  # (..., seq_len_q, seq_len_k)
        # Scale by sqrt(d_k); cast to the logits dtype so float16/float64 work.
        dk = tf.cast(tf.shape(k)[-1], matmul_qk.dtype)
        scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)
        if mask is not None:
            # Cast so bool/int masks are accepted as well as float ones.
            scaled_attention_logits += (
                tf.cast(mask, scaled_attention_logits.dtype) * -1e9
            )
        # Softmax over the last axis (seq_len_k) so the scores add up to 1.
        attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)
        output = tf.matmul(attention_weights, v)  # (..., seq_len_q, depth_v)
        return output, attention_weights

    def split_heads(self, x, batch_size):
        """Split the last dimension into (num_heads, depth).

        Transpose the result such that the shape is
        (batch_size, num_heads, seq_len, depth).
        """
        x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
        return tf.transpose(x, perm=[0, 2, 1, 3])

    def call(self, v, k, q, mask=None):
        """Attend from ``q`` over ``k``/``v``.

        Inputs are expected as (batch_size, seq_len, features), matching the
        shape comments below.

        Returns:
            Tuple ``(output, attention_weights)`` with shapes
            (batch_size, seq_len_q, d_model) and
            (batch_size, num_heads, seq_len_q, seq_len_k).
        """
        batch_size = tf.shape(q)[0]
        q = self.wq(q)  # (batch_size, seq_len, d_model)
        k = self.wk(k)  # (batch_size, seq_len, d_model)
        v = self.wv(v)  # (batch_size, seq_len, d_model)
        q = self.split_heads(q, batch_size)  # (batch_size, num_heads, seq_len_q, depth)
        k = self.split_heads(k, batch_size)  # (batch_size, num_heads, seq_len_k, depth)
        v = self.split_heads(v, batch_size)  # (batch_size, num_heads, seq_len_v, depth)
        # scaled_attention.shape == (batch_size, num_heads, seq_len_q, depth)
        # attention_weights.shape == (batch_size, num_heads, seq_len_q, seq_len_k)
        scaled_attention, attention_weights = self.scaled_dot_product_attention(
            q, k, v, mask
        )
        # Merge heads back: (batch_size, seq_len_q, num_heads, depth) ...
        scaled_attention = tf.transpose(scaled_attention, perm=[0, 2, 1, 3])
        # ... -> (batch_size, seq_len_q, d_model)
        concat_attention = tf.reshape(scaled_attention, (batch_size, -1, self.d_model))
        output = self.dense(concat_attention)  # (batch_size, seq_len_q, d_model)
        return output, attention_weights
FYI, as of TF 2.4, the `tf.keras.layers.MultiHeadAttention` layer has been officially added.
# Built-in Keras multi-head attention layer.
layer = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=2)
input_tensor = tf.keras.Input(shape=[2, 2, 32])
print(input_tensor.shape)
# Self-attention: the same symbolic tensor is passed as query and value.
mha_output = layer(input_tensor, input_tensor)
print(mha_output.shape)
You can test these two as follows:
# custom layer MHA
multi = MultiHeadAttention(d_model=512, num_heads=2)
y = tf.random.uniform((1, 60, 512))
out, attn = multi(y, k=y, q=y, mask=None)
print(out.shape, attn.shape)
# Expected output: (1, 60, 512) (1, 2, 60, 60)

# built-in layer; key_dim = d_model // num_heads (512 // 2) so each head has
# the same size as in the custom layer above.
layer = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=256)
y = tf.random.uniform((1, 60, 512))
out, attn = layer(y, y, return_attention_scores=True)
print(out.shape, attn.shape)
# Expected output: (1, 60, 512) (1, 2, 60, 60)