Tags: tensorflow, keras, deep-learning, neural-network, attention-model

Output shapes of Keras AdditiveAttention Layer


I'm trying to use the AdditiveAttention layer in Keras. With a manual implementation of the layer from the TensorFlow tutorial https://www.tensorflow.org/tutorials/text/nmt_with_attention:

import tensorflow as tf 

class BahdanauAttention(tf.keras.layers.Layer):
  def __init__(self, units):
    super(BahdanauAttention, self).__init__()
    self.W1 = tf.keras.layers.Dense(units)
    self.W2 = tf.keras.layers.Dense(units)
    self.V = tf.keras.layers.Dense(1)

  def call(self, query, values):
    query_with_time_axis = tf.expand_dims(query, 1)
    score = self.V(tf.nn.tanh(
        self.W1(query_with_time_axis) + self.W2(values)))
    attention_weights = tf.nn.softmax(score, axis=1)

    # context_vector shape after sum == (batch_size, hidden_size)
    context_vector = attention_weights * values
    context_vector = tf.reduce_sum(context_vector, axis=1)
    return context_vector, attention_weights

The shape of the context_vector is (batch_size, units).

Whereas with the built-in Keras AdditiveAttention layer

from tensorflow.keras.layers import AdditiveAttention

the shape of the context_vector is [batch_size, Tq, dim].

Any suggestions on what causes this difference in output shape would be useful.


Solution

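  • Both implementations compute additive (Bahdanau-style) attention, with some differences. The BahdanauAttention layer in that tutorial is a simplified, adapted version that also applies learned linear projections (W1, W2 and V) before scoring. The context_vector shape you are seeing comes down to the input shapes each layer expects: the tutorial layer takes one query vector per example, query of shape (batch_size, hidden_size), and sums over the time axis, so it returns (batch_size, hidden_size); the built-in layer takes a query sequence of shape [batch_size, Tq, dim] and returns one context vector per query step, [batch_size, Tq, dim]. Here is a demonstration, starting with the tutorial implementation: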

    import tensorflow as tf
    
    class BahdanauAttention(tf.keras.layers.Layer):
      def __init__(self, units):
        super(BahdanauAttention, self).__init__()
        self.W1 = tf.keras.layers.Dense(units)
        self.W2 = tf.keras.layers.Dense(units)
        self.V  = tf.keras.layers.Dense(1)
    
      def call(self, query, values):
        query_with_time_axis = tf.expand_dims(query, 1)
        score = self.V(tf.nn.tanh(self.W1(query_with_time_axis) + self.W2(values)))
        attention_weights = tf.nn.softmax(score, axis=1)
        context_vector = attention_weights * values
        context_vector = tf.reduce_sum(context_vector, axis=1)
        return context_vector, attention_weights
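
    For comparison, the tutorial layer is meant to be called with one query vector per example, query of shape (batch_size, hidden_size), and values of shape (batch_size, Tv, hidden_size). The NumPy sketch below mirrors its call() under those shapes. The random matrices standing in for the Dense layers (biases omitted) are assumptions for illustration, so only the shapes, not the values, correspond to the TensorFlow layer.

```python
import numpy as np

def softmax(x, axis):
    # Numerically stable softmax along `axis`.
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def tutorial_bahdanau(query, values, W1, W2, V):
    # query: [batch, hidden] (one decoder state per example)
    # values: [batch, Tv, hidden] (encoder outputs)
    q = query[:, None, :]                      # [batch, 1, hidden], like tf.expand_dims(query, 1)
    score = np.tanh(q @ W1 + values @ W2) @ V  # [batch, Tv, 1]
    weights = softmax(score, axis=1)           # normalise over the Tv axis
    context = (weights * values).sum(axis=1)   # [batch, hidden]
    return context, weights

rng = np.random.default_rng(0)
batch, Tv, hidden, units = 2, 60, 512, 10
# Random matrices stand in for the Dense layers (assumption; biases omitted).
W1 = rng.normal(size=(hidden, units))
W2 = rng.normal(size=(hidden, units))
V = rng.normal(size=(units, 1))

query = rng.normal(size=(batch, hidden))
values = rng.normal(size=(batch, Tv, hidden))
context, weights = tutorial_bahdanau(query, values, W1, W2, V)
print(context.shape, weights.shape)  # (2, 512) (2, 60, 1)
```

    With these intended inputs the context vector is (batch_size, hidden_size) and the weights are (batch_size, Tv, 1): the query contributes no time axis of its own.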
    

    Now we pass it some inputs, first 3D and then 2D:

    attention_layer = BahdanauAttention(10)
    
    y = tf.random.uniform((2, 60, 512))
    out, attn = attention_layer(y, y)
    out.shape, attn.shape
    # (TensorShape([2, 60, 512]), TensorShape([2, 2, 60, 1]))
    
    y = tf.random.uniform((2, 512))
    out, attn = attention_layer(y, y)
    out.shape, attn.shape
    # (TensorShape([2, 512]), TensorShape([2, 2, 1]))
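
    The extra leading 2 in the first attention_weights shape, [2, 2, 60, 1], comes from broadcasting. The tutorial layer expects a 2D query, so with a 3D query W1(query_with_time_axis) has shape [2, 1, 60, 10] while W2(values) has shape [2, 60, 10]. Broadcasting aligns trailing axes, so the batch axis of values lines up with the axis inserted into the query, pairing every example's query with every example's values. A minimal NumPy illustration, using zero arrays of those projected shapes (units = 10, as above):

```python
import numpy as np

batch, T, units = 2, 60, 10
projected_query = np.zeros((batch, 1, T, units))  # shape of W1(tf.expand_dims(query, 1)) for a 3D query
projected_values = np.zeros((batch, T, units))    # shape of W2(values)

# Trailing axes are aligned, so (2, 1, 60, 10) + (2, 60, 10) -> (2, 2, 60, 10)
summed = projected_query + projected_values
print(summed.shape)  # (2, 2, 60, 10)
```

    After V this becomes [2, 2, 60, 1], and the softmax over axis=1 then normalises across examples in the batch rather than across time steps.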
    

    Now pass the same inputs to the built-in AdditiveAttention layer and see what we get:

    built_attn = tf.keras.layers.AdditiveAttention()
    
    y = tf.random.uniform((2, 60, 512))
    out, attn = built_attn([y, y], return_attention_scores=True)
    out.shape, attn.shape
    # (TensorShape([2, 60, 512]), TensorShape([2, 60, 60]))
    
    y = tf.random.uniform((2, 512))
    out, attn = built_attn([y, y], return_attention_scores=True)
    out.shape, attn.shape
    # (TensorShape([2, 512]), TensorShape([2, 2]))
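
    The built-in layer handles a 2D input in a similar way: the reshape in its first step (adding an axis at position -2 for the query and -3 for the value) turns the batch axis of a [2, 512] tensor into a time axis, which is why the scores come out as [2, 2]. A NumPy sketch of just that reshape and the score sum, assuming those axis positions and ignoring the layer's optional learned scale:

```python
import numpy as np

x = np.random.default_rng(0).uniform(size=(2, 512))  # a 2D input, as in the example above
q = np.expand_dims(x, -2)             # (2, 1, 512): read as [Tq, 1, dim]
k = np.expand_dims(x, -3)             # (1, 2, 512): read as [1, Tv, dim]
scores = np.tanh(q + k).sum(axis=-1)  # (2, 2): every row scored against every row
print(q.shape, k.shape, scores.shape)  # (2, 1, 512) (1, 2, 512) (2, 2)
```

    The output is then distribution @ value, i.e. (2, 2) @ (2, 512) -> (2, 512), which matches the shapes printed above.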
    

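    So the context_vector shapes match here, but the attention_weights shapes do not. Note that for the 3D input only the shape matches, not the computation: the tutorial layer is written for a 2D query, and a 3D query makes it pair each example with every other example's values. The reason for the difference is that the tutorial's implementation is a modified, adapted version. The built-in AdditiveAttention layer (documented as Bahdanau-style attention) computes: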

    1. Reshape query and value into shapes [batch_size, Tq, 1, dim] and [batch_size, 1, Tv, dim] respectively.
    2. Calculate scores with shape [batch_size, Tq, Tv] as a non-linear sum: scores = tf.reduce_sum(tf.tanh(query + value), axis=-1)
    3. Use scores to calculate a distribution with shape [batch_size, Tq, Tv]: distribution = tf.nn.softmax(scores).
    4. Use distribution to create a linear combination of values with shape [batch_size, Tq, dim]: return tf.matmul(distribution, value).
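
    Steps 1 to 4 can be sketched in NumPy as follows. This is a simplified mirror of the documented computation, not the Keras source: it assumes key = value and omits masking, dropout and the learned scale the layer applies when use_scale=True. It also shows how giving the query a length-1 time axis (Tq = 1) yields one context vector per example, the shape the tutorial returns.

```python
import numpy as np

def softmax(x, axis=-1):
    # Numerically stable softmax along `axis`.
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def additive_attention(query, value):
    # query: [batch, Tq, dim], value (also used as key): [batch, Tv, dim]
    q = query[:, :, None, :]                   # step 1: [batch, Tq, 1, dim]
    v = value[:, None, :, :]                   # step 1: [batch, 1, Tv, dim]
    scores = np.tanh(q + v).sum(axis=-1)       # step 2: [batch, Tq, Tv]
    distribution = softmax(scores, axis=-1)    # step 3: [batch, Tq, Tv]
    return distribution @ value, distribution  # step 4: [batch, Tq, dim]

rng = np.random.default_rng(0)
seq = rng.uniform(size=(2, 60, 512))

out, attn = additive_attention(seq, seq)
print(out.shape, attn.shape)  # (2, 60, 512) (2, 60, 60)

# One query vector per example: add a Tq axis of length 1, then drop it.
single_query = rng.uniform(size=(2, 512))
out1, attn1 = additive_attention(single_query[:, None, :], seq)
context = out1[:, 0, :]
print(out1.shape, attn1.shape, context.shape)  # (2, 1, 512) (2, 1, 60) (2, 512)
```

    In TensorFlow the same idea would be to pass query[:, tf.newaxis, :] as the query to AdditiveAttention and squeeze axis 1 of the result.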

    The tutorial computes the attention weights differently: it scores a single query against all Tv positions, so with its intended 2D query the weights have shape (batch_size, Tv, 1). If we follow steps 1 to 4 above instead, we get the same attention_weights shape as the built-in layer. Here is how (for demonstration purposes only; this is not a general replacement):

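    class BahdanauAttention(tf.keras.layers.Layer):
      def __init__(self, units):
        super(BahdanauAttention, self).__init__()
        # Kept only so the constructor matches the tutorial; this
        # simplified call() does not use W1, W2 or V.
        self.W1 = tf.keras.layers.Dense(units)
        self.W2 = tf.keras.layers.Dense(units)
        self.V = tf.keras.layers.Dense(1)
    
      def call(self, query, values):
        query_with_time_axis = tf.expand_dims(query, 2)   # [batch_size, Tq, 1, dim]
        value_with_time_axis = tf.expand_dims(values, 1)  # [batch_size, 1, Tv, dim]
        # Step 2: non-linear sum over the feature axis -> [batch_size, Tq, Tv]
        scores = tf.reduce_sum(tf.tanh(query_with_time_axis +
                                       value_with_time_axis), axis=-1)
        # Step 3: softmax over Tv (the last axis) -> [batch_size, Tq, Tv]
        distribution = tf.nn.softmax(scores)
        # Step 4: weighted sum of values -> [batch_size, Tq, dim]
        return tf.matmul(distribution, values), distribution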
    

    Now, passing the same input gives the same output shapes from both implementations. In general use cases, however, the built-in implementation is the better choice.

    attention_layer = BahdanauAttention(10)
    
    y = tf.random.uniform((2, 60, 512))
    out, attn = attention_layer(y, y)
    out.shape, attn.shape
    # (TensorShape([2, 60, 512]), TensorShape([2, 60, 60]))
    
    built_attn = tf.keras.layers.AdditiveAttention()
    y = tf.random.uniform((2, 60, 512))
    out, attn = built_attn([y, y], return_attention_scores=True)
    out.shape, attn.shape
    # (TensorShape([2, 60, 512]), TensorShape([2, 60, 60]))
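
    The two layers are closer than their shapes suggest. If the tutorial's projections are made trivial (W1 and W2 as identity matrices, V as a column of ones, biases dropped; assumptions chosen only to strip out the learned parts), its computation on a 2D query equals steps 1 to 4 with Tq = 1. A NumPy check of that claim, with both functions being simplified sketches rather than the TensorFlow code:

```python
import numpy as np

def softmax(x, axis):
    # Numerically stable softmax along `axis`.
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def tutorial_bahdanau(query, values, W1, W2, V):
    # query: [batch, hidden], values: [batch, Tv, hidden]
    score = np.tanh(query[:, None, :] @ W1 + values @ W2) @ V  # [batch, Tv, 1]
    weights = softmax(score, axis=1)
    return (weights * values).sum(axis=1), weights             # [batch, hidden], [batch, Tv, 1]

def additive_attention(query, value):
    # query: [batch, Tq, dim], value: [batch, Tv, dim]
    scores = np.tanh(query[:, :, None, :] + value[:, None, :, :]).sum(axis=-1)
    distribution = softmax(scores, axis=-1)                    # [batch, Tq, Tv]
    return distribution @ value, distribution                  # [batch, Tq, dim]

rng = np.random.default_rng(1)
batch, Tv, dim = 2, 60, 16
query = rng.normal(size=(batch, dim))
values = rng.normal(size=(batch, Tv, dim))

# Trivial "projections": identity W1/W2 and a ones vector for V.
ctx_t, w_t = tutorial_bahdanau(query, values, np.eye(dim), np.eye(dim), np.ones((dim, 1)))
ctx_k, w_k = additive_attention(query[:, None, :], values)

print(np.allclose(ctx_t, ctx_k[:, 0, :]))      # True
print(np.allclose(w_t[..., 0], w_k[:, 0, :]))  # True
```

    Under these assumptions, what separates the two layers is the tutorial's learned projections (and the Keras layer's optional scale) plus the input shapes each one expects, rather than a different scoring formula.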