I have two questions related to tf.keras.layers.LSTMCell. Let's look at the following code:
import tensorflow as tf

inputs = tf.random.normal([32, 10, 8])
rnn1 = tf.keras.layers.RNN(tf.keras.layers.LSTMCell(4))
output = rnn1(inputs)
rnn2 = tf.keras.layers.RNN(
    tf.keras.layers.LSTMCell(4),
    return_sequences=True,
    return_state=True)
whole_seq_output, final_memory_state, final_carry_state = rnn2(inputs)
From the outputs of rnn2, I can see that the final_memory_state is contained in the whole_seq_output:
>>> tf.reduce_all(whole_seq_output[:, -1, :] == final_memory_state)
<tf.Tensor: shape=(), dtype=bool, numpy=True>
Hence, I think the final_memory_state is the final cell state, while the whole_seq_output contains all the cell states. Also, the final_carry_state is the final hidden state. The cell state and hidden state are referred to as C_t and h_t in this well-known tutorial. Is my understanding correct?
Also, from rnn1, the output is not one of final_memory_state or final_carry_state:
>>> tf.reduce_all(output == final_carry_state)
<tf.Tensor: shape=(), dtype=bool, numpy=False>
>>> tf.reduce_all(output == final_memory_state)
<tf.Tensor: shape=(), dtype=bool, numpy=False>
I think the only difference between rnn1 and rnn2 is how the values are returned, so the output should be one of final_memory_state or final_carry_state. Could you help explain?
After testing several times, it turns out that the whole_seq_output contains the outputs at every time step, while the final_memory_state is the output at the final time step. Both correspond to the h_t (hidden state) in the aforementioned tutorial.
Also, final_carry_state is the cell state (i.e., the C_t in the tutorial).
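Lastly, output is indeed final_memory_state. Their values are the same if the same cell is used. In my original code I used two different cells, each with its own randomly initialized weights, which is why the comparison returned False. Sharing one cell fixes that: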
import tensorflow as tf

inputs = tf.random.normal([32, 10, 8])
# Share one cell so that both RNN layers use the same weights.
cell = tf.keras.layers.LSTMCell(4)
rnn1 = tf.keras.layers.RNN(cell)
output = rnn1(inputs)
rnn2 = tf.keras.layers.RNN(
    cell,
    return_sequences=True,
    return_state=True)
whole_seq_output, final_memory_state, final_carry_state = rnn2(inputs)
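To see why the last step of the sequence output matches the hidden state and not the cell state, here is a minimal NumPy sketch of the textbook LSTM recurrence. It is not Keras's actual implementation: the helper run_lstm, the weight names W, U, b, and the random weights are all made up for illustration. It only assumes the usual update rules C_t = f * C_{t-1} + i * g and h_t = o * tanh(C_t).

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def run_lstm(inputs, W, U, b, units):
    """Run a textbook LSTM over inputs of shape (batch, time, features).

    Returns (h_t for every step, final hidden state h_T, final cell state C_T).
    Hypothetical helper for illustration only, not part of any library.
    """
    batch, steps, _ = inputs.shape
    h = np.zeros((batch, units))  # hidden state h_t, starts at zero
    c = np.zeros((batch, units))  # cell state C_t, starts at zero
    hidden_per_step = []
    for t in range(steps):
        # One fused projection, then split into the four gate pre-activations.
        z = inputs[:, t, :] @ W + h @ U + b
        i, f, g, o = np.split(z, 4, axis=1)
        i, f, o = sigmoid(i), sigmoid(f), sigmoid(o)
        g = np.tanh(g)
        c = f * c + i * g        # cell state update (C_t)
        h = o * np.tanh(c)       # hidden state / per-step output (h_t)
        hidden_per_step.append(h)
    return np.stack(hidden_per_step, axis=1), h, c


rng = np.random.default_rng(0)
units, features = 4, 8
inputs = rng.normal(size=(32, 10, features))
W = rng.normal(size=(features, 4 * units))  # input weights (assumed random)
U = rng.normal(size=(units, 4 * units))     # recurrent weights (assumed random)
b = np.zeros(4 * units)

whole_seq_output, final_h, final_c = run_lstm(inputs, W, U, b, units)

# The last slice of the sequence output is h_T, not C_T.
last_step_is_h = np.array_equal(whole_seq_output[:, -1, :], final_h)
last_step_is_c = np.array_equal(whole_seq_output[:, -1, :], final_c)
print(whole_seq_output.shape, last_step_is_h, last_step_is_c)
```

In this sketch the per-step output is h_t by construction, which mirrors what the comparison against final_memory_state shows for the Keras layer, while the cell state is carried along separately and never appears in the sequence output.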