Tags: pytorch, lstm, recurrent-neural-network

Understanding the architecture of an LSTM for sequence classification


I have this model in pytorch that I have been using for sequence classification.

class RoBERT_Model(nn.Module):

    def __init__(self, hidden_size=100):
        super().__init__()  # call before assigning any submodules
        self.hidden_size = hidden_size
        self.lstm = nn.LSTM(768, hidden_size, num_layers=1, bidirectional=False)
        self.out = nn.Linear(hidden_size, 2)

    def forward(self, grouped_pooled_outs):
        # length of each chunk sequence in the batch
        seq_lengths = torch.LongTensor([len(x) for x in grouped_pooled_outs])

        # pad every sequence to the longest one with value -91
        batch_emb_pad = nn.utils.rnn.pad_sequence(grouped_pooled_outs, padding_value=-91, batch_first=True)
        batch_emb = batch_emb_pad.transpose(0, 1)  # (B, L, D) -> (L, B, D)

        # pack so the LSTM skips the padded steps; lengths must live on the CPU
        lstm_input = nn.utils.rnn.pack_padded_sequence(batch_emb, seq_lengths.cpu(), batch_first=False, enforce_sorted=False)

        packed_output, (h_t, h_c) = self.lstm(lstm_input)
        h_t = h_t.view(-1, self.hidden_size)  # (num_layers, B, hidden_size) -> (B, hidden_size)

        return self.out(h_t)  # logits

The issue I am having is that I am not entirely sure what data is being passed to the final classification layer. I believe that only the output of the final LSTM cell in the last layer is used for classification; that is, hidden_size features are passed to the feedforward layer.

I have depicted what I believe is going on in this figure here:

[figure: the asker's diagram of the LSTM architecture feeding the classifier]

Is this understanding correct? Am I missing anything?

Thanks.


Solution

  • Your code is a basic LSTM for classification, using a single RNN layer.

    In your picture you have multiple LSTM layers, while in reality there is only one (H_n^0 in the picture).

    1. Your input to the LSTM is of shape (B, L, D), as correctly pointed out in the comment.
    2. packed_output and h_c are not used at all, so you can change that line to _, (h_t, _) = self.lstm(lstm_input) so as not to clutter the picture further.
    3. h_t is the hidden state at the last step for each batch element; in general its shape is (D * num_layers, B, hidden_size). As this network is not bidirectional, D = 1, and as you have a single layer, num_layers = 1 as well, hence the output is of shape (1, B, hidden_size).
    4. This output is reshaped into an nn.Linear-compatible tensor (this line: h_t = h_t.view(-1, self.hidden_size)), which gives a tensor of shape (B, hidden_size).
    5. This tensor is fed to a single nn.Linear layer.

    In general, the hidden state of the last time step of the RNN (H_n^0 in your picture) is taken for each element in the batch and simply fed to the classifier.
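    The shape bookkeeping above can be checked with a minimal sketch; the dummy sequence lengths here are illustrative, only the layer sizes match the question:

    import torch
    import torch.nn as nn

    hidden_size = 100
    lstm = nn.LSTM(768, hidden_size, num_layers=1, bidirectional=False)
    out = nn.Linear(hidden_size, 2)

    # three "documents" with 4, 2 and 3 chunk embeddings of size 768 each
    seqs = [torch.randn(4, 768), torch.randn(2, 768), torch.randn(3, 768)]
    lengths = torch.tensor([len(s) for s in seqs])

    padded = nn.utils.rnn.pad_sequence(seqs, padding_value=-91, batch_first=True)  # (B, L, D)
    packed = nn.utils.rnn.pack_padded_sequence(padded.transpose(0, 1), lengths.cpu(),
                                               batch_first=False, enforce_sorted=False)

    _, (h_t, _) = lstm(packed)
    print(h_t.shape)    # torch.Size([1, 3, 100]), i.e. (num_layers * D, B, hidden_size)

    logits = out(h_t.view(-1, hidden_size))
    print(logits.shape) # torch.Size([3, 2]), i.e. (B, num_classes)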

    By the way, having self.out = nn.Linear(hidden_size, 2) for this classification is probably counter-productive; most likely you are performing binary classification, in which case self.out = nn.Linear(hidden_size, 1) with torch.nn.BCEWithLogitsLoss could be used instead. A single logit contains the information about whether the label should be 0 or 1; anything below 0 is more likely to be 0 according to the network, and anything above 0 is considered a 1 label.
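    A hedged sketch of that suggestion; the feature tensor and labels below are stand-ins, not part of the asker's model:

    import torch
    import torch.nn as nn

    out = nn.Linear(100, 1)             # single logit instead of two
    criterion = nn.BCEWithLogitsLoss()  # applies the sigmoid internally

    features = torch.randn(3, 100)      # stand-in for h_t after .view(-1, hidden_size)
    labels = torch.tensor([0., 1., 1.]) # float targets of shape (B,)

    logits = out(features).squeeze(1)   # (B, 1) -> (B,)
    loss = criterion(logits, labels)

    preds = (logits > 0).long()         # logit > 0 -> label 1, otherwise label 0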