Tags: tensorflow, recurrent-neural-network, sequence-to-sequence

How to use AttentionMechanism with MultiRNNCell and dynamic_decode?


I want to create a multi-layered dynamic RNN-based decoder that uses an attention mechanism. To do this, I first create an attention mechanism:

attention_mechanism = BahdanauAttention(num_units=ATTENTION_UNITS,
                                        memory=encoder_outputs,
                                        normalize=True)

Then I use the AttentionWrapper to wrap an LSTM cell with the attention mechanism:

attention_wrapper = AttentionWrapper(cell=self._create_lstm_cell(DECODER_SIZE),
                                     attention_mechanism=attention_mechanism,
                                     output_attention=False,
                                     alignment_history=True,
                                     attention_layer_size=ATTENTION_LAYER_SIZE)

where self._create_lstm_cell is defined as follows:

@staticmethod
def _create_lstm_cell(cell_size):
    return BasicLSTMCell(cell_size)

I then do some bookkeeping (e.g. creating my MultiRNNCell, creating an initial state, creating a TrainingHelper, etc.):

attention_zero = attention_wrapper.zero_state(batch_size=tf.flags.FLAGS.batch_size,
                                              dtype=tf.float32)

# define initial state
initial_state = attention_zero.clone(cell_state=encoder_final_states[0])

training_helper = TrainingHelper(inputs=self.y,                   # feed in ground truth
                                 sequence_length=self.y_lengths)  # feed in sequence lengths

layered_cell = MultiRNNCell(
    [attention_wrapper] + [ResidualWrapper(self._create_lstm_cell(cell_size=DECODER_SIZE))
                           for _ in range(NUMBER_OF_DECODER_LAYERS - 1)])

decoder = BasicDecoder(cell=layered_cell,
                       helper=training_helper,
                       initial_state=initial_state)

decoder_outputs, decoder_final_state, decoder_final_sequence_lengths = dynamic_decode(
    decoder=decoder,
    maximum_iterations=tf.flags.FLAGS.max_number_of_scans // 12,
    impute_finished=True)

But I receive the following error: AttributeError: 'LSTMStateTuple' object has no attribute 'attention'.

What is the correct way to add an attention mechanism to a MultiRNNCell dynamic decoder?


Solution

  • Have you tried using the AttentionCellWrapper provided by tf.contrib.rnn?

    Here is an example using both an attention wrapper and dropout:

    import tensorflow as tf

    # n_layers, n_hidden and batch_size are hyperparameters defined elsewhere
    cells = []
    for i in range(n_layers):
        cell = tf.contrib.rnn.LSTMCell(n_hidden, state_is_tuple=True)

        # wrap each layer so it attends over a window of its own recent outputs
        cell = tf.contrib.rnn.AttentionCellWrapper(
            cell, attn_length=40, state_is_tuple=True)

        cell = tf.contrib.rnn.DropoutWrapper(cell, output_keep_prob=0.5)
        cells.append(cell)

    cell = tf.contrib.rnn.MultiRNNCell(cells, state_is_tuple=True)
    init_state = cell.zero_state(batch_size, tf.float32)
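    Below is a minimal sketch of how this stacked cell might then be plugged into the question's decoder setup; it assumes a TrainingHelper built exactly as in the question, and since the stack no longer carries an AttentionWrapperState, the plain zero state can be passed straight in as the decoder's initial state:

    from tensorflow.contrib.seq2seq import BasicDecoder, dynamic_decode

    # training_helper: assumed to be the TrainingHelper from the question
    decoder = BasicDecoder(cell=cell,
                           helper=training_helper,
                           initial_state=init_state)

    decoder_outputs, decoder_final_state, decoder_final_lengths = dynamic_decode(
        decoder=decoder,
        impute_finished=True)

    Note that tf.contrib.rnn.AttentionCellWrapper attends over a fixed window (attn_length) of the cell's own previous outputs rather than over an external encoder memory, so it is not a drop-in replacement for the tf.contrib.seq2seq AttentionWrapper/BahdanauAttention combination used in the question.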