keras · deep-learning · nlp · lstm · seq2seq

Keras seq2seq model Output Shapes


I am working through the Keras seq2seq example here: https://blog.keras.io/a-ten-minute-introduction-to-sequence-to-sequence-learning-in-keras.html. What I have understood from the text is that in the decoder model, each cell's output is the input to the next cell. However, I don't understand how this recursion is implemented in the model. The link builds the decoder model as follows:

decoder_model = Model(
        [decoder_inputs] + decoder_states_inputs,
        [decoder_outputs] + decoder_states)

How does this syntax tell the model that each cell's output is the input to the next cell? In general, how does this syntax work?

EDIT: If you check the keras.Model documentation, you will see that a model can take a list of keras.Input objects as its input argument; notice that [decoder_inputs] + decoder_states_inputs is such a list.
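To see that the bracket syntax is just ordinary Python list concatenation (and nothing Keras-specific), here is a minimal sketch using plain strings as stand-ins for the Keras tensors:

```python
# Stand-ins for the Keras tensors (assumption: plain strings, used only to
# show the list mechanics; in the real code these are keras.Input tensors)
decoder_inputs = "decoder_inputs"
decoder_states_inputs = ["state_input_h", "state_input_c"]

# [decoder_inputs] + decoder_states_inputs wraps the single tensor in a
# one-element list, then concatenates the two state tensors onto it
model_inputs = [decoder_inputs] + decoder_states_inputs
print(model_inputs)
# ['decoder_inputs', 'state_input_h', 'state_input_c']
```

So Model() simply receives a three-element list of input tensors; the recursion itself happens later, in the inference loop, not in this line.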


Solution

  • If you look at the documentation for the Keras Model class here, you'll see that the Model() function takes in inputs and outputs as its first and second arguments respectively (Model(inputs, outputs)). This specifies the input and output layers of the model (in your case, a decoder that will be used in the inference loop of the decode_sequence() function at the end of the article you linked).

    To elaborate more on the code snippet you posted, you are providing decoder_inputs and decoder_states_inputs together as the inputs argument of Model(inputs, outputs) to specify the input layer of the decoder model:

    • decoder_inputs is an Input object (Keras tensor) of shape (None, num_decoder_tokens), instantiated with the Input() function (see Input); it accepts the one-hot-encoded input tokens (characters).

    • Similarly, decoder_states_inputs is a list of two Input tensors, one for the decoder's hidden state and one for its cell state, each of shape (latent_dim,).
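    As a sketch, the three input tensors can be constructed like this (latent_dim and num_decoder_tokens are toy values here; the tutorial derives them from the dataset):

```python
from keras.layers import Input

# Toy sizes standing in for the tutorial's latent_dim and num_decoder_tokens
# (the real values depend on the dataset; these are assumptions)
latent_dim = 256
num_decoder_tokens = 93

# Token input: a variable-length sequence of one-hot vectors
decoder_inputs = Input(shape=(None, num_decoder_tokens))

# One Input each for the LSTM's hidden state (h) and cell state (c)
decoder_state_input_h = Input(shape=(latent_dim,))
decoder_state_input_c = Input(shape=(latent_dim,))
decoder_states_inputs = [decoder_state_input_h, decoder_state_input_c]
```

    Concatenating these as [decoder_inputs] + decoder_states_inputs gives the three-tensor input list that Model() expects.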

    And again, you provide the decoder_outputs and decoder_states together as the outputs argument of Model(inputs, outputs) to specify the output layer of the model:

    • decoder_outputs is the output of a densely connected layer with softmax activation over the decoder tokens (see Dense).
    • decoder_states is a list containing the hidden state state_h and cell state state_c of the decoder_lstm.
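The recursion you asked about does not live in the Model(...) call at all; it happens in the decode_sequence() inference loop, where each predicted token and the returned states are fed back into the next predict call. Here is a runnable sketch of that loop with decoder_model.predict() replaced by a random stub (the stub, the token indices, and the sizes are assumptions for illustration):

```python
import numpy as np

# Toy sizes and hypothetical start/stop token indices (stand-ins for the
# tutorial's '\t' and '\n' characters)
num_decoder_tokens = 5
latent_dim = 4
max_len = 10
START, STOP = 0, 1

rng = np.random.default_rng(0)

def decoder_model_predict(target_seq, states):
    """Stub standing in for decoder_model.predict([target_seq] + states)."""
    output_tokens = rng.random((1, 1, num_decoder_tokens))
    h = rng.random((1, latent_dim))
    c = rng.random((1, latent_dim))
    return output_tokens, h, c

# Initial states would come from the encoder; zeros here for simplicity
states = [np.zeros((1, latent_dim)), np.zeros((1, latent_dim))]

# Seed the target sequence with the start token
target_seq = np.zeros((1, 1, num_decoder_tokens))
target_seq[0, 0, START] = 1.0

decoded = []
for _ in range(max_len):
    output_tokens, h, c = decoder_model_predict(target_seq, states)
    token = int(np.argmax(output_tokens[0, -1, :]))
    if token == STOP:
        break
    decoded.append(token)
    # The recursion: this step's predicted token becomes the next step's
    # input, and the returned h/c states are fed back in as well
    target_seq = np.zeros((1, 1, num_decoder_tokens))
    target_seq[0, 0, token] = 1.0
    states = [h, c]

print(decoded)
```

Because the decoder model exposes its states as both inputs and outputs, this loop can thread them through one step at a time, which is exactly how "each cell's output is input to the next cell" is realized at inference time.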