Tags: machine-learning, neural-network, pytorch, recurrent-neural-network

What is a good use of the intermediate hidden states of an RNN?


So I've used RNN/LSTMs in three different capacities:

  1. Many to many: Use every output of the final layer to predict the next element of the sequence. This could be classification or regression.
  2. Many to one: Use the final hidden state to perform regression or classification.
  3. One to many: Take a latent space vector, perhaps the final hidden state of an LSTM encoder, and use it to generate a sequence (I've done this in the form of an autoencoder).

In none of these cases do I use the intermediate hidden states to generate my final output: only the last layer's outputs in case #1, and only the last layer's final hidden state in cases #2 and #3. However, PyTorch's nn.LSTM/nn.RNN returns a tensor containing the final hidden state of every layer, so I assume those states have some use.
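
For reference, here is a minimal sketch of what a multi-layer nn.LSTM returns with batch_first=True; all sizes are arbitrary and chosen only for illustration:

```python
import torch
import torch.nn as nn

# Arbitrary sizes, only for illustration
batch, seq_len, input_size, hidden_size, num_layers = 4, 10, 8, 16, 3

lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
x = torch.randn(batch, seq_len, input_size)

output, (h_n, c_n) = lstm(x)

print(output.shape)  # torch.Size([4, 10, 16]) -- last layer's output at every time step
print(h_n.shape)     # torch.Size([3, 4, 16])  -- final hidden state of every layer
print(c_n.shape)     # torch.Size([3, 4, 16])  -- final cell state of every layer
```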

What are some use cases for those intermediate layer states?


Solution

  • There’s nothing explicitly requiring you to use only the last layer. You could feed the hidden states of all layers into your final classifier MLP, either at each position in the sequence or at the end, if you’re classifying the whole sequence (see the sketch at the end of this answer).

    As a practical example, consider the ELMo architecture for generating contextualized (that is, token-level) word embeddings. (Paper here: https://www.aclweb.org/anthology/N18-1202/) The representations are the hidden states of a multi-layer biRNN. Figure 2 in the paper shows how different layers differ in usefulness depending on the task. The authors suggest that lower levels encode syntax, while higher levels encode semantics.
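
As a rough sketch of the first point above (this is not the ELMo model itself): assuming a whole-sequence classification task, you could concatenate the final hidden state of every LSTM layer and pass that to an MLP. The class name and all sizes below are made up for illustration.

```python
import torch
import torch.nn as nn

class AllLayersClassifier(nn.Module):
    """Hypothetical sketch: classify a sequence from every layer's final hidden state."""
    def __init__(self, input_size=8, hidden_size=16, num_layers=3, num_classes=5):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        # The MLP sees the concatenation of all layers' final hidden states,
        # not just the top layer's.
        self.mlp = nn.Sequential(
            nn.Linear(num_layers * hidden_size, 64),
            nn.ReLU(),
            nn.Linear(64, num_classes),
        )

    def forward(self, x):
        _, (h_n, _) = self.lstm(x)              # h_n: (num_layers, batch, hidden_size)
        h_all = h_n.transpose(0, 1).flatten(1)  # (batch, num_layers * hidden_size)
        return self.mlp(h_all)

model = AllLayersClassifier()
logits = model(torch.randn(4, 10, 8))  # -> shape (4, 5)
```

An ELMo-style alternative to concatenation would be a learned, task-specific weighted sum of the per-layer states.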