Output shapes of Keras AdditiveAttention Layer

I am trying to use the AdditiveAttention layer in Keras. With the manual implementation of the layer from the TensorFlow tutorial:

import tensorflow as tf 

class BahdanauAttention(tf.keras.layers.Layer):
  def __init__(self, units):
    super(BahdanauAttention, self).__init__()
    self.W1 = tf.keras.layers.Dense(units)
    self.W2 = tf.keras.layers.Dense(units)
    self.V = tf.keras.layers.Dense(1)

  def call(self, query, values):
    query_with_time_axis = tf.expand_dims(query, 1)
    score = self.V(tf.nn.tanh(
        self.W1(query_with_time_axis) + self.W2(values)))
    attention_weights = tf.nn.softmax(score, axis=1)

    # context_vector shape after sum == (batch_size, hidden_size)
    context_vector = attention_weights * values
    context_vector = tf.reduce_sum(context_vector, axis=1)
    return context_vector, attention_weights

the shape of the context_vector is (batch_size, hidden_size),

whereas with the built-in AdditiveAttention layer from Keras:

from tensorflow.keras.layers import AdditiveAttention

the shape of the context_vector is [batch_size, Tq, dim].

Any suggestions on what is causing this difference in output shape would be useful.


  • Both implementations are broadly similar, with some variation. The BahdanauAttention in that tutorial is a simplified, adapted version that adds learned linear transformations (W1, W2, V). The context_vector shape you are wondering about comes down to the shape of the input data. Here is a demonstration. First, the tutorial implementation:

    class BahdanauAttention(tf.keras.layers.Layer):
      def __init__(self, units):
        super(BahdanauAttention, self).__init__()
        self.W1 = tf.keras.layers.Dense(units)
        self.W2 = tf.keras.layers.Dense(units)
        self.V  = tf.keras.layers.Dense(1)
      def call(self, query, values):
        query_with_time_axis = tf.expand_dims(query, 1)
        score = self.V(tf.nn.tanh(self.W1(query_with_time_axis) + self.W2(values)))
        attention_weights = tf.nn.softmax(score, axis=1)
        context_vector = attention_weights * values
        context_vector = tf.reduce_sum(context_vector, axis=1)
        return context_vector, attention_weights

    Now we pass it some inputs, first 3D and then 2D.

    attention_layer = BahdanauAttention(10)
    y = tf.random.uniform((2, 60, 512))  
    out, attn = attention_layer(y, y)
    out.shape , attn.shape
    (TensorShape([2, 60, 512]), TensorShape([2, 2, 60, 1]))
    y = tf.random.uniform((2, 512))  
    out, attn = attention_layer(y, y)
    out.shape , attn.shape
    (TensorShape([2, 512]), TensorShape([2, 2, 1]))
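
For comparison, here is a rough NumPy sketch of how the layer behaves when it is called the way the tutorial appears to use it: a 2D query (one decoder state per example) and 3D values (all encoder steps). Plain matrices stand in for the Dense layers (biases are left out, since they do not change shapes), and `softmax` and `tutorial_attention` are illustrative helper names, not part of Keras. The second call repeats the 3D/3D case from above.

```python
import numpy as np

def softmax(x, axis):
    # Numerically stable softmax along the given axis.
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def tutorial_attention(query, values, W1, W2, V):
    # Mirrors call() above; W1, W2, V are plain matrices standing in
    # for the Dense layers (assumption: biases omitted).
    query_with_time_axis = np.expand_dims(query, 1)
    score = np.tanh(query_with_time_axis @ W1 + values @ W2) @ V
    attention_weights = softmax(score, axis=1)
    context_vector = (attention_weights * values).sum(axis=1)
    return context_vector, attention_weights

rng = np.random.default_rng(0)
batch, Tv, hidden, units = 2, 60, 512, 10
W1 = rng.normal(size=(hidden, units))
W2 = rng.normal(size=(hidden, units))
V = rng.normal(size=(units, 1))

query = rng.normal(size=(batch, hidden))       # one query vector per example
values = rng.normal(size=(batch, Tv, hidden))  # Tv value vectors per example

# 2D query, 3D values: one context vector per example.
context, weights = tutorial_attention(query, values, W1, W2, V)
print(context.shape, weights.shape)    # (2, 512) (2, 60, 1)

# 3D query, 3D values: broadcasting adds an extra axis to the weights.
context3, weights3 = tutorial_attention(values, values, W1, W2, V)
print(context3.shape, weights3.shape)  # (2, 60, 512) (2, 2, 60, 1)
```

In the 3D/3D case the axis added to the query lines up with the batch axis of values during broadcasting, which is where the extra dimension in attention_weights comes from.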

    Now pass the same inputs to the built-in AdditiveAttention and see what we get:

    built_attn = tf.keras.layers.AdditiveAttention()
    y = tf.random.uniform((2, 60, 512))
    out, attn = built_attn([y, y], return_attention_scores=True)
    out.shape , attn.shape
    (TensorShape([2, 60, 512]), TensorShape([2, 60, 60]))
    y = tf.random.uniform((2, 512))
    out, attn = built_attn([y, y], return_attention_scores=True)
    out.shape , attn.shape
    (TensorShape([2, 512]), TensorShape([2, 2]))

    So the shape of the context_vector is the same in both cases, but the shape of attention_weights is not. The reason, as mentioned, is that the tutorial's implementation is, I believe, modified and adapted. If we look at how the built-in AdditiveAttention (Bahdanau-style attention) is calculated, we get:

    1. Reshape query and value into shapes [batch_size, Tq, 1, dim] and [batch_size, 1, Tv, dim] respectively.
    2. Calculate scores with shape [batch_size, Tq, Tv] as a non-linear sum: scores = tf.reduce_sum(tf.tanh(query + value), axis=-1)
    3. Use scores to calculate a distribution with shape [batch_size, Tq, Tv]: distribution = tf.nn.softmax(scores).
    4. Use distribution to create a linear combination of values with shape [batch_size, Tq, dim]: return tf.matmul(distribution, value).
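
Steps 1 to 4 can be sketched in NumPy with different query and value lengths, which makes it easier to see which axis is Tq and which is Tv. This is only a shape illustration of the simplified formula quoted above, leaving out any learned scale or masking the real layer may apply; `additive_attention_sketch` is a made-up helper name.

```python
import numpy as np

def additive_attention_sketch(query, value):
    # Step 1: add broadcast axes.
    q = query[:, :, None, :]   # [batch_size, Tq, 1, dim]
    v = value[:, None, :, :]   # [batch_size, 1, Tv, dim]
    # Step 2: non-linear sum over the feature axis.
    scores = np.tanh(q + v).sum(axis=-1)   # [batch_size, Tq, Tv]
    # Step 3: softmax over the last axis (Tv).
    e = np.exp(scores - scores.max(axis=-1, keepdims=True))
    distribution = e / e.sum(axis=-1, keepdims=True)
    # Step 4: weighted combination of the values.
    return distribution @ value, distribution   # [batch_size, Tq, dim]

rng = np.random.default_rng(0)
batch_size, Tq, Tv, dim = 2, 3, 5, 4
query = rng.normal(size=(batch_size, Tq, dim))
value = rng.normal(size=(batch_size, Tv, dim))

out, dist = additive_attention_sketch(query, value)
print(out.shape, dist.shape)   # (2, 3, 4) (2, 3, 5)
```

With Tq = 1 the output is [batch_size, 1, dim], that is, one context vector per example with an extra length-1 axis.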

    I think the implementation in that tutorial calculates the attention weights a bit differently. If we follow the approach above (steps 1 to 4), we get the same output shape for attention_weights as well. Here is how (note that this is for demonstration purposes only, not a general implementation):

    class BahdanauAttention(tf.keras.layers.Layer):
      def __init__(self, units):
        super(BahdanauAttention, self).__init__()
        # W1, W2 and V mirror the tutorial class but are not used in call() below.
        self.W1 = tf.keras.layers.Dense(units)
        self.W2 = tf.keras.layers.Dense(units)
        self.V = tf.keras.layers.Dense(1)
      def call(self, query, values):
        query_with_time_axis = tf.expand_dims(query, 2)  # [batch_size, Tq, 1, dim]
        value_with_time_axis = tf.expand_dims(values, 1) # [batch_size, 1, Tv, dim]
        scores = tf.reduce_sum(tf.tanh(query_with_time_axis + 
                                       value_with_time_axis), axis=-1)
        distribution = tf.nn.softmax(scores)
        return tf.matmul(distribution, values), distribution

    Now, if we pass the same input, we get the same output shapes from both implementations. For general use cases, however, the built-in implementation should be preferred.

    attention_layer = BahdanauAttention(10)
    y = tf.random.uniform((2, 60, 512))  
    out, attn = attention_layer(y, y)
    out.shape , attn.shape
    (TensorShape([2, 60, 512]), TensorShape([2, 60, 60]))
    built_attn = tf.keras.layers.AdditiveAttention()
    y = tf.random.uniform((2, 60, 512))
    out, attn = built_attn([y, y], return_attention_scores=True)
    out.shape , attn.shape
    (TensorShape([2, 60, 512]), TensorShape([2, 60, 60]))