
Tensorflow RNN LSTM output explanation


I have two questions related to tf.keras.layers.LSTMCell. Let's look at the following code:

inputs = tf.random.normal([32, 10, 8])
rnn1 = tf.keras.layers.RNN(tf.keras.layers.LSTMCell(4))
output = rnn1(inputs)

rnn2 = tf.keras.layers.RNN(
   tf.keras.layers.LSTMCell(4),
   return_sequences=True,
   return_state=True)
whole_seq_output, final_memory_state, final_carry_state = rnn2(inputs)

From the outputs of rnn2, I can see that final_memory_state appears as the last time step of whole_seq_output:

tf.reduce_all(whole_seq_output[:,-1,:]==final_memory_state)
<tf.Tensor: shape=(), dtype=bool, numpy=True>

Hence, I think final_memory_state is the final cell state, whole_seq_output contains all the cell states, and final_carry_state is the final hidden state. The cell state and hidden state are referred to as C_t and h_t in this well-known tutorial. Is my understanding correct?

Also, from rnn1, the output equals neither final_memory_state nor final_carry_state:

>>> tf.reduce_all(output == final_carry_state)
<tf.Tensor: shape=(), dtype=bool, numpy=False>
>>> tf.reduce_all(output == final_memory_state)
<tf.Tensor: shape=(), dtype=bool, numpy=False>

I think the only difference between rnn1 and rnn2 is how the values are returned, so output should equal one of final_memory_state or final_carry_state. Could you help explain?


Solution

  • After testing several times, it turns out that whole_seq_output contains the outputs (hidden states) at every time step, while final_memory_state is the output at the final time step; both correspond to h_t in the aforementioned tutorial. final_carry_state is the final cell state (the C_t in the tutorial). Lastly, output is indeed the same as final_memory_state. The two only differed in the question because rnn1 and rnn2 were built from two separate LSTMCell instances with independently initialized weights; if the same cell is shared, the values match:

    inputs = tf.random.normal([32, 10, 8])
    cell = tf.keras.layers.LSTMCell(4)
    rnn1 = tf.keras.layers.RNN(cell)
    output = rnn1(inputs)
    
    rnn2 = tf.keras.layers.RNN(
       cell,
       return_sequences=True,
       return_state=True)
    whole_seq_output, final_memory_state, final_carry_state = rnn2(inputs)
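For completeness, a minimal self-contained check of the claim above (the seed is arbitrary, used only for reproducibility): sharing one LSTMCell makes rnn1's default output identical to final_memory_state, which in turn is the last time step of whole_seq_output.

```python
import tensorflow as tf

tf.random.set_seed(0)  # seed only for reproducibility
inputs = tf.random.normal([32, 10, 8])  # (batch, time steps, features)

# Share a single cell so both RNN wrappers use identical weights.
cell = tf.keras.layers.LSTMCell(4)
rnn1 = tf.keras.layers.RNN(cell)
output = rnn1(inputs)

rnn2 = tf.keras.layers.RNN(cell, return_sequences=True, return_state=True)
whole_seq_output, final_memory_state, final_carry_state = rnn2(inputs)

# The default return is the final hidden state h_T ...
print(bool(tf.reduce_all(output == final_memory_state)))                      # True
# ... which is also the last time step of the full hidden-state sequence.
print(bool(tf.reduce_all(whole_seq_output[:, -1, :] == final_memory_state)))  # True
```

Note that final_carry_state (C_T) is still distinct from both: the cell state is an internal accumulator and is never part of whole_seq_output.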