Tags: python, tensorflow, keras, deep-learning, attention-model

How to build an attention model with Keras?


I am trying to understand attention models and to build one myself. After many searches I came across this website, which has an attention model coded in Keras that also looks simple. But when I tried to build the same model on my machine, it gave a multiple-argument error. The error comes from mismatched arguments in the Attention class: its constructor asks for one argument, yet the code instantiates the attention object with two arguments.

import tensorflow as tf
from tensorflow import keras

max_len = 200
rnn_cell_size = 128
vocab_size=250

class Attention(tf.keras.Model):
    """Bahdanau-style additive attention over the time steps of `features`."""
    def __init__(self, units):
        super(Attention, self).__init__()
        self.W1 = tf.keras.layers.Dense(units)
        self.W2 = tf.keras.layers.Dense(units)
        self.V = tf.keras.layers.Dense(1)

    def call(self, features, hidden):
        # features: (batch, max_len, hidden_size); hidden: (batch, hidden_size)
        hidden_with_time_axis = tf.expand_dims(hidden, 1)
        # Additive score per time step, normalized with softmax over the time axis
        score = tf.nn.tanh(self.W1(features) + self.W2(hidden_with_time_axis))
        attention_weights = tf.nn.softmax(self.V(score), axis=1)
        # Context vector is the attention-weighted sum of the features
        context_vector = attention_weights * features
        context_vector = tf.reduce_sum(context_vector, axis=1)
        return context_vector, attention_weights

sequence_input = tf.keras.layers.Input(shape=(max_len,), dtype='int32')

embedded_sequences = tf.keras.layers.Embedding(vocab_size, 128, input_length=max_len)(sequence_input)

lstm = tf.keras.layers.Bidirectional(
    tf.keras.layers.LSTM(rnn_cell_size,
                         dropout=0.3,
                         return_sequences=True,
                         return_state=True,
                         recurrent_activation='relu',
                         recurrent_initializer='glorot_uniform'),
    name="bi_lstm_0")(embedded_sequences)

lstm, forward_h, forward_c, backward_h, backward_c = tf.keras.layers.Bidirectional(
    tf.keras.layers.LSTM(rnn_cell_size,
                         dropout=0.2,
                         return_sequences=True,
                         return_state=True,
                         recurrent_activation='relu',
                         recurrent_initializer='glorot_uniform'))(lstm)

state_h = tf.keras.layers.Concatenate()([forward_h, backward_h])
state_c = tf.keras.layers.Concatenate()([forward_c, backward_c])

#  PROBLEM IN THIS LINE
context_vector, attention_weights = Attention(lstm, state_h)

output = keras.layers.Dense(1, activation='sigmoid')(context_vector)

model = keras.Model(inputs=sequence_input, outputs=output)

# summarize layers
print(model.summary())

How can I make this model work?


Solution

  • There is a problem with the way you initialize the attention layer and pass it its inputs. You should specify the number of attention units when constructing the layer, and pass the features and hidden state in a separate call:

    context_vector, attention_weights = Attention(32)(lstm, state_h)
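
    In other words, Attention(32) runs __init__, which only needs the number of units, while the separate call with (lstm, state_h) is routed to call(features, hidden). A minimal standalone sketch of the same construct-then-call pattern (with hypothetical random tensors shaped like the BiLSTM outputs, assuming the Attention class from the question is in scope):

    import tensorflow as tf

    # Illustrative shapes only: (batch, time, hidden) like the BiLSTM output,
    # and (batch, hidden) like the concatenated final hidden state.
    features = tf.random.normal((2, 200, 256))
    hidden = tf.random.normal((2, 256))

    attention_layer = Attention(units=32)                  # __init__ receives only `units`
    context, weights = attention_layer(features, hidden)   # call() receives (features, hidden)
    print(context.shape, weights.shape)                    # (2, 256) (2, 200, 1)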
    

    The result:

    __________________________________________________________________________________________________
    Layer (type)                    Output Shape         Param #     Connected to                     
    ==================================================================================================
    input_1 (InputLayer)            (None, 200)          0                                            
    __________________________________________________________________________________________________
    embedding (Embedding)           (None, 200, 128)     32000       input_1[0][0]                    
    __________________________________________________________________________________________________
    bi_lstm_0 (Bidirectional)       [(None, 200, 256), ( 263168      embedding[0][0]                  
    __________________________________________________________________________________________________
    bidirectional (Bidirectional)   [(None, 200, 256), ( 394240      bi_lstm_0[0][0]                  
                                                                     bi_lstm_0[0][1]                  
                                                                     bi_lstm_0[0][2]                  
                                                                     bi_lstm_0[0][3]                  
                                                                     bi_lstm_0[0][4]                  
    __________________________________________________________________________________________________
    concatenate (Concatenate)       (None, 256)          0           bidirectional[0][1]              
                                                                     bidirectional[0][3]              
    __________________________________________________________________________________________________
    attention (Attention)           [(None, 256), (None, 16481       bidirectional[0][0]              
                                                                     concatenate[0][0]                
    __________________________________________________________________________________________________
    dense_3 (Dense)                 (None, 1)            257         attention[0][0]                  
    ==================================================================================================
    Total params: 706,146
    Trainable params: 706,146
    Non-trainable params: 0
    __________________________________________________________________________________________________
    None
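
    With that change the graph builds. To actually train it, one possible setup with dummy data (illustrative only, not from the original post) would be:

    import numpy as np

    # Random integer sequences and binary labels, just to show that the
    # fixed model compiles and trains end to end with the sigmoid output.
    x_dummy = np.random.randint(0, vocab_size, size=(32, max_len))
    y_dummy = np.random.randint(0, 2, size=(32, 1)).astype('float32')

    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
    model.fit(x_dummy, y_dummy, epochs=1, batch_size=8)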