Tags: python, deep-learning, pytorch, lstm

Query regarding Pytorch LSTM code snippet


In the Stack Overflow thread How can i add a Bi-LSTM layer on top of bert model?, there is a line of code:

hidden = torch.cat((lstm_output[:,-1, :256],lstm_output[:,0, 256:]),dim=-1)

Can someone explain why the outputs at the last and first positions are concatenated, and not any others? What do these two positions contain that makes them the ones chosen?


Solution

  • In bidirectional models, the hidden states of the forward and backward directions are concatenated at every time step. The line therefore concatenates the first 256 units (:256) of the output at the last time step (-1), which is the forward direction's final hidden state, with the last 256 units (256:) of the output at the first time step (0), which is the backward direction's final hidden state. Each of those positions is where its direction has processed the entire sequence, so together they hold the most "interesting" summary of the input; the sketch after this answer verifies the layout.

    I've written a longer, more detailed answer on how hidden states are constructed in PyTorch's recurrent modules.
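
    As a concrete check of this layout, here is a minimal sketch (the shapes are illustrative; hidden_size=256 matches the slice width in the snippet, and the BERT encoder from the original thread is omitted). It shows that the two slices are exactly the final hidden states of the forward and backward directions, as also returned in h_n:

    import torch
    import torch.nn as nn

    # Illustrative shapes; 256 matches the slice width used in the snippet.
    batch_size, seq_len, input_size, hidden_size = 4, 10, 32, 256

    lstm = nn.LSTM(input_size, hidden_size, batch_first=True, bidirectional=True)
    x = torch.randn(batch_size, seq_len, input_size)

    # lstm_output: (batch, seq_len, 2 * hidden_size); forward and backward
    # hidden states are concatenated along the last dimension at every step.
    # h_n: (num_directions, batch, hidden_size) holds each direction's final state.
    lstm_output, (h_n, c_n) = lstm(x)

    # Forward direction finishes at the last time step (-1), in units [:256];
    # backward direction finishes at the first time step (0), in units [256:].
    hidden = torch.cat((lstm_output[:, -1, :256], lstm_output[:, 0, 256:]), dim=-1)

    # Both slices match the final hidden states returned in h_n.
    assert torch.allclose(lstm_output[:, -1, :256], h_n[0])  # forward final state
    assert torch.allclose(lstm_output[:, 0, 256:], h_n[1])   # backward final state
    print(hidden.shape)  # torch.Size([4, 512])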