Tags: tensorflow, keras, lstm, bidirectional, seq2seq

Understanding states of a bidirectional LSTM in a seq2seq model (tf keras)


I am creating a language model: a seq2seq model with 2 Bidirectional LSTM layers. I have got the model to train and the accuracy seems good, but while working out the inference model I have found myself confused by the states returned by each LSTM layer.

I am using this tutorial as a guide, though the example in that link does not use bidirectional layers: https://blog.keras.io/a-ten-minute-introduction-to-sequence-to-sequence-learning-in-keras.html

Note: I am using a pretrained word embedding.

from tensorflow.keras.layers import Input, Embedding, LSTM, Bidirectional, Dense, TimeDistributed
from tensorflow.keras.models import Model

lstm_units = 100

# Set up embedding layer using pretrained weights
embedding_layer = Embedding(total_words+1, emb_dimension, input_length=max_input_len, weights=[embedding_matrix], name="Embedding")

# Encoder
encoder_input_x = Input(shape=(None,), name="Enc_Input")
encoder_embedding_x = embedding_layer(encoder_input_x)
encoder_lstm_x, enc_state_h_fwd, enc_state_c_fwd, enc_state_h_bwd, enc_state_c_bwd = Bidirectional(LSTM(lstm_units, dropout=0.5, return_state=True, name="Enc_LSTM1"), name="Enc_Bi1")(encoder_embedding_x)
encoder_states = [enc_state_h_fwd, enc_state_c_fwd, enc_state_h_bwd, enc_state_c_bwd]

# Decoder
decoder_input_x = Input(shape=(None,), name="Dec_Input")
decoder_embedding_x = embedding_layer(decoder_input_x)
decoder_lstm_layer = Bidirectional(LSTM(lstm_units, return_state=True, return_sequences=True, dropout=0.5, name="Dec_LSTM1"))
decoder_lstm_x, _, _, _, _ = decoder_lstm_layer(decoder_embedding_x, initial_state=encoder_states)
decoder_dense_layer = TimeDistributed(Dense(total_words+1, activation="softmax", name="Dec_Softmax"))
decoder_output_x = decoder_dense_layer(decoder_lstm_x)

model = Model(inputs=[encoder_input_x, decoder_input_x], outputs=decoder_output_x)

model.compile(loss='sparse_categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

I believe the diagram of the model looks like this, with 60 time steps: [model diagram]

I want the encoder to pass the enc_state_h_fwd and enc_state_c_fwd forward to the decoder. This connection is highlighted by the orange arrow.

But since the model is bidirectional, I have some questions:

  1. Do I need to pass the decoder states backwards to the encoder? How would one even do this? It seems like a chicken-and-egg scenario.
  2. The encoder Bidirectional LSTM layer outputs 4 states: h and c states going forward and backward. I believe the "backward" states are denoted in my diagram by the pink arrow going left out of the encoder. I am passing these to the decoder, but why does it need them? Am I incorrectly connecting the pink arrow on the left to the purple arrow going into the decoder from the right?

Solution

  • This model is not valid. It is set up as a translation model, which during inference would predict one word at a time: starting with the start-of-sequence token it predicts y1, then it loops, feeding in the start-of-sequence token and y1 to get y2, and so on (a sketch of this decoding loop is given at the end of this answer).

    A bidirectional LSTM cannot be used for real-time, many-to-many prediction unless the entire decoder input is available. In this case, the decoder input only becomes available one predicted step at a time, so the first prediction (of y1) is invalid without the rest of the sequence (y2-yt).

    The decoder should therefore not be a bidirectional LSTM.

    As for the states, the encoder Bidirectional LSTM does indeed output h and c states going forward (orange arrow), and h and c states going backward (pink arrow).

    By concatenating these states and feeding them to the decoder, we can give the decoder more information. This is possible because we do have the entire encoder input at inference time (a sketch of this wiring follows below).

    Also note that the bidirectional encoder with lstm_units (e.g. 100) effectively has 200 LSTM units, half going forward and half going backward. To take the concatenated states as its initial state, the decoder must have 200 units too.
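
    A minimal sketch of that wiring, reusing the question's embedding_layer, lstm_units and total_words, with the decoder changed to a plain (unidirectional) LSTM of 2 * lstm_units so the concatenated encoder states fit as its initial state:

from tensorflow.keras.layers import Input, LSTM, Bidirectional, Dense, TimeDistributed, Concatenate
from tensorflow.keras.models import Model

# Encoder: same Bidirectional LSTM as in the question.
encoder_input_x = Input(shape=(None,), name="Enc_Input")
encoder_embedding_x = embedding_layer(encoder_input_x)
_, enc_h_fwd, enc_c_fwd, enc_h_bwd, enc_c_bwd = Bidirectional(LSTM(lstm_units, dropout=0.5, return_state=True, name="Enc_LSTM1"), name="Enc_Bi1")(encoder_embedding_x)

# Concatenate forward and backward states -> vectors of size 2*lstm_units.
enc_state_h = Concatenate(name="Enc_state_h")([enc_h_fwd, enc_h_bwd])
enc_state_c = Concatenate(name="Enc_state_c")([enc_c_fwd, enc_c_bwd])
encoder_states = [enc_state_h, enc_state_c]

# Decoder: ordinary LSTM with 2*lstm_units, initialised with the concatenated states.
decoder_input_x = Input(shape=(None,), name="Dec_Input")
decoder_embedding_x = embedding_layer(decoder_input_x)
decoder_lstm = LSTM(2 * lstm_units, return_sequences=True, return_state=True, dropout=0.5, name="Dec_LSTM1")
decoder_lstm_x, _, _ = decoder_lstm(decoder_embedding_x, initial_state=encoder_states)
decoder_dense = TimeDistributed(Dense(total_words + 1, activation="softmax"), name="Dec_Softmax")
decoder_output_x = decoder_dense(decoder_lstm_x)

model = Model([encoder_input_x, decoder_input_x], decoder_output_x)
model.compile(loss="sparse_categorical_crossentropy", optimizer="adam", metrics=["accuracy"])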
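
    And for the inference step the question was stuck on, a sketch of the usual two-model setup plus a greedy decoding loop, assuming the corrected decoder_lstm and decoder_dense defined above and hypothetical start_token_id / end_token_id values for the start- and end-of-sequence tokens:

import numpy as np

# Inference-time encoder: maps an input sequence to the concatenated h and c states.
encoder_model = Model(encoder_input_x, encoder_states)

# Inference-time decoder: runs one step from externally supplied states.
dec_state_h_in = Input(shape=(2 * lstm_units,), name="Dec_h_in")
dec_state_c_in = Input(shape=(2 * lstm_units,), name="Dec_c_in")
dec_out, dec_h, dec_c = decoder_lstm(embedding_layer(decoder_input_x), initial_state=[dec_state_h_in, dec_state_c_in])
decoder_model = Model([decoder_input_x, dec_state_h_in, dec_state_c_in], [decoder_dense(dec_out), dec_h, dec_c])

def decode_sequence(input_seq, start_token_id, end_token_id, max_len=60):
    h, c = encoder_model.predict(input_seq, verbose=0)   # encode the full input once
    target = np.array([[start_token_id]])                # start-of-sequence token
    decoded = []
    for _ in range(max_len):
        probs, h, c = decoder_model.predict([target, h, c], verbose=0)
        next_id = int(np.argmax(probs[0, -1, :]))        # greedy choice of the next word
        if next_id == end_token_id:
            break
        decoded.append(next_id)
        target = np.array([[next_id]])                   # feed the prediction back in
    return decoded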