neural-network, pytorch, lstm, recurrent-neural-network

PyTorch LSTM - generating a sentence word by word?


I'm trying to implement a neural network to generate sentences (image captions), and I'm using PyTorch's LSTM (nn.LSTM) for that.

The input I want to feed in during training is of size batch_size * seq_size * embedding_size, where seq_size is the maximal length of a sentence. For example: 64*30*512.

After the LSTM there is one FC layer (nn.Linear). As far as I understand, this type of network works with a hidden state (h, c in this case) and predicts the next word at each step.

My question is: during training, do we have to manually feed the sentence word by word to the LSTM in the forward function, or does the LSTM know how to do it itself?

My forward function looks like this:

    def forward(self, features, caption, h=None, c=None):
        batch_size = caption.size(0)
        caption_size = caption.size(1)

        no_hc = False
        if h is None and c is None:
            no_hc = True
            h, c = self.init_hidden(batch_size)

        embeddings = self.embedding(caption)
        output = torch.empty((batch_size, caption_size, self.vocab_size)).to(device)

        for i in range(caption_size):  # go over the words in the sentence
            if i == 0:
                # first step: feed the image features
                lstm_input = features.unsqueeze(1)
            else:
                # later steps: feed the embedding of the previous word
                lstm_input = embeddings[:, i - 1, :].unsqueeze(1)

            out, (h, c) = self.lstm(lstm_input, (h, c))
            out = self.fc(out)

            output[:, i, :] = out.squeeze(1)

        if no_hc:
            return output

        return output, h, c

(took inspiration from here)
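(For context, the forward above assumes the module defines self.embedding, self.lstm, self.fc and self.init_hidden, which aren't shown in the question. A minimal sketch of such a constructor, with illustrative sizes and a hypothetical class name, could look like this:)

    import torch
    import torch.nn as nn

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    class CaptionDecoder(nn.Module):
        def __init__(self, vocab_size, embedding_size=512, hidden_size=512):
            super().__init__()
            self.vocab_size = vocab_size
            self.hidden_size = hidden_size
            self.embedding = nn.Embedding(vocab_size, embedding_size)
            self.lstm = nn.LSTM(embedding_size, hidden_size, batch_first=True)
            self.fc = nn.Linear(hidden_size, vocab_size)

        def init_hidden(self, batch_size):
            # single LSTM layer, so shape is (num_layers, batch_size, hidden_size)
            h = torch.zeros(1, batch_size, self.hidden_size).to(device)
            c = torch.zeros(1, batch_size, self.hidden_size).to(device)
            return h, c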

The output of the forward here is of size batch_size * seq_size * vocab_size, which is good because it can be compared with the original caption of size batch_size * seq_size in the loss function.

The question is whether this for loop inside the forward, which feeds the words one after the other, is really necessary, or whether I can somehow feed the entire sentence at once and get the same results.

(I saw some examples that do that, for example this one, but I'm not sure if it's really equivalent.)


Solution

  • The answer is that the LSTM knows how to do it on its own; you do not have to manually feed each word one by one. An intuitive way to understand it: the shape of the batch you pass in contains seq_length (batch.shape[1]), which the LSTM uses as the number of time steps in the sentence. The words are then passed through the LSTM cell step by step internally, producing the hidden state h and cell state c.
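
    As a concrete illustration (a sketch with made-up sizes, not the asker's exact model), the whole caption can be pushed through nn.LSTM in a single call, using the same shift-by-one input scheme as the loop in the question:

    import torch
    import torch.nn as nn

    batch_size, seq_size, embedding_size = 64, 30, 512
    hidden_size, vocab_size = 512, 10000            # illustrative sizes

    lstm = nn.LSTM(embedding_size, hidden_size, batch_first=True)
    fc = nn.Linear(hidden_size, vocab_size)

    features = torch.randn(batch_size, embedding_size)              # image features
    embeddings = torch.randn(batch_size, seq_size, embedding_size)  # embedded caption

    # Image features as the first "token", then the caption embeddings
    # shifted by one position, exactly like the per-step loop does.
    lstm_input = torch.cat([features.unsqueeze(1), embeddings[:, :-1, :]], dim=1)

    out, (h, c) = lstm(lstm_input)   # out: (batch_size, seq_size, hidden_size)
    logits = fc(out)                 # (batch_size, seq_size, vocab_size)

    Note that this is equivalent to the loop only during training with teacher forcing, i.e. when every step's input is a ground-truth word (or the image features), which is exactly what the questioner's loop does. At inference time, where each step's input is the word generated at the previous step, you still need to run the LSTM step by step.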