In the Stack Overflow thread "How can i add a Bi-LSTM layer on top of bert model?", there is this line of code:
hidden = torch.cat((lstm_output[:,-1, :256],lstm_output[:,0, 256:]),dim=-1)
Can someone explain why the last and first tokens are concatenated, and not any others? What do these two positions contain that made them the choice?
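In bidirectional models, the forward and backward hidden states get concatenated at each time step. Assuming the LSTM was built with hidden_size=256, lstm_output has 512 features per step: the first 256 (:256) come from the forward direction and the last 256 (256:) from the backward direction. The line therefore concatenates the forward units at the last time step (-1) with the backward units at the first time step (0). These are the two positions where each direction has finished reading the whole sequence: the forward pass at index -1 has consumed every token from left to right, and the backward pass at index 0 has consumed every token from right to left. Together they give a summary of the entire input sequence from both directions.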
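A minimal sketch that checks this layout, under stated assumptions: a single-layer nn.LSTM with hidden_size=256, bidirectional=True and batch_first=True, no padding or packed sequences, and a random tensor of width 768 standing in for BERT's token outputs (the sizes and the random input are illustrative, not taken from the original thread). It shows that the two slices from the question are exactly the final hidden states PyTorch returns in h_n for each direction, while the backward half at the last position is not.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Illustrative sizes (assumption): 768 mimics BERT-base's hidden width.
batch_size, seq_len, input_size, hidden_size = 4, 10, 768, 256

# Single-layer bidirectional LSTM, batch dimension first.
lstm = nn.LSTM(input_size, hidden_size, num_layers=1,
               batch_first=True, bidirectional=True)

# Random stand-in for BERT's per-token outputs (assumption, not real BERT).
x = torch.randn(batch_size, seq_len, input_size)

# lstm_output: (batch, seq_len, 2 * hidden_size); features [:256] are forward, [256:] backward.
# h_n: (num_layers * 2, batch, hidden_size); h_n[0] is forward final, h_n[1] is backward final.
lstm_output, (h_n, c_n) = lstm(x)

# Forward direction after reading all tokens left to right.
forward_last = lstm_output[:, -1, :hidden_size]
# Backward direction after reading all tokens right to left.
backward_last = lstm_output[:, 0, hidden_size:]

# The concatenation from the question: one full-sequence summary per direction.
hidden = torch.cat((forward_last, backward_last), dim=-1)

print(lstm_output.shape)  # torch.Size([4, 10, 512])
print(hidden.shape)       # torch.Size([4, 512])
print(torch.allclose(forward_last, h_n[0]),
      torch.allclose(backward_last, h_n[1]))  # True True

# By contrast, the backward half at the last position has only seen one token.
backward_at_end = lstm_output[:, -1, hidden_size:]
print(torch.allclose(backward_at_end, h_n[1]))
```

Taking the backward half from index -1 instead would give a state that has read only the final token, which is why the slice for that direction comes from index 0.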
I've written a longer, more detailed answer on how PyTorch constructs hidden states for its recurrent modules.