Tags: tensorflow, deep-learning, neural-network, nlp

How does Masking work in the scaled_dot_product_attention of Language understanding Transformers?


I have been following TensorFlow's tutorial on Transformers for language understanding (here). However, I'm a bit confused about the masks used in the function scaled_dot_product_attention. I know what masks are used for, but I don't understand how they actually work in this function.

When I followed the tutorial, I understood that the mask is a matrix indicating which elements are padding elements (value 1 in the masking matrix) and which are not (value 0 in the masking matrix), for example (a sketch of how such a mask is built follows the matrix):

[0, 0, 1
 1, 0, 0
 0, 1, 0]
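
For reference, here is a minimal sketch of how such a padding mask can be built, in the spirit of the tutorial's create_padding_mask (written from memory, and assuming the padding token id is 0):

import tensorflow as tf

def create_padding_mask(seq):
  # 1.0 where the token id equals the assumed padding id 0, else 0.0
  mask = tf.cast(tf.math.equal(seq, 0), tf.float32)
  # add extra dimensions so the mask broadcasts over the attention logits,
  # final shape: (batch_size, 1, 1, seq_len)
  return mask[:, tf.newaxis, tf.newaxis, :]

seq = tf.constant([[7, 6, 0, 0, 1]])
print(create_padding_mask(seq))  # [[[[0. 0. 1. 1. 0.]]]]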

However, I can see that the function scaled_dot_product_attention updates the padded elements by adding a very large negative number, -1e9 (negative one billion). This can be seen in the following lines of the function:

  if mask is not None:
    scaled_attention_logits += (mask * -1e9)

Why is this done, and how does it mathematically lead to these values being ignored? Below is the implementation shown in the tutorial:

def scaled_dot_product_attention(q, k, v, mask):
  """Calculate the attention weights.
  q, k, v must have matching leading dimensions.
  k, v must have matching penultimate dimension, i.e.: seq_len_k = seq_len_v.
  The mask has different shapes depending on its type (padding or look ahead)
  but it must be broadcastable for addition.

  Args:
    q: query shape == (..., seq_len_q, depth)
    k: key shape == (..., seq_len_k, depth)
    v: value shape == (..., seq_len_v, depth_v)
    mask: Float tensor with shape broadcastable 
          to (..., seq_len_q, seq_len_k). Defaults to None.

  Returns:
    output, attention_weights
  """

  matmul_qk = tf.matmul(q, k, transpose_b=True)  # (..., seq_len_q, seq_len_k)

  # scale matmul_qk
  dk = tf.cast(tf.shape(k)[-1], tf.float32)
  scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)

  # add the mask to the scaled tensor.
  if mask is not None:
    scaled_attention_logits += (mask * -1e9)  

  # softmax is normalized on the last axis (seq_len_k) so that the scores
  # add up to 1.
  attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)  # (..., seq_len_q, seq_len_k)

  output = tf.matmul(attention_weights, v)  # (..., seq_len_q, depth_v)

  return output, attention_weights

Solution

  • OK, so the value -1e9 acts as a stand-in for negative infinity. Adding it drives the logits of the masked positions to a hugely negative number, so the softmax function assigns those positions a probability of effectively 0, and they are ignored when the attention values are computed.
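
    As a quick numerical sketch (toy logits, not taken from the tutorial) showing the effect of adding mask * -1e9 before the softmax:

    import tensorflow as tf

    logits = tf.constant([[2.0, 1.0, 0.5]])   # toy attention logits for one query
    mask   = tf.constant([[0.0, 0.0, 1.0]])   # 1.0 marks the padded key position

    masked_logits = logits + mask * -1e9      # padded position drops to about -1e9
    weights = tf.nn.softmax(masked_logits, axis=-1)
    print(weights.numpy())
    # roughly [[0.731 0.269 0.   ]]
    # exp(-1e9) underflows to 0, so the padded key gets practically zero attention
    # weight and contributes nothing to tf.matmul(attention_weights, v).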