I'm trying to implement a neural network to generate sentences (image captions), and I'm using PyTorch's LSTM (nn.LSTM) for that.
The input I want to feed during training is of size batch_size * seq_size * embedding_size, where seq_size is the maximal length of a sentence. For example: 64 * 30 * 512.
After the LSTM there is one FC layer (nn.Linear).
As far as I understand, this type of network works with a hidden state (h, c in this case) and predicts the next word each time.
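(To make that concrete, here is a minimal sketch, with made-up sizes, of how I understand nn.LSTM consumes and returns the (h, c) state:)

    import torch
    import torch.nn as nn

    # Made-up sizes, just to show the shapes involved.
    lstm = nn.LSTM(input_size=512, hidden_size=256, batch_first=True)
    x = torch.randn(64, 1, 512)        # one time step: batch * 1 * embedding_size
    h0 = torch.zeros(1, 64, 256)       # (num_layers, batch, hidden_size)
    c0 = torch.zeros(1, 64, 256)
    out, (h1, c1) = lstm(x, (h0, c0))  # (h1, c1) carries the state to the next step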
My question is: during training, do we have to manually feed the sentence word by word to the LSTM in the forward function, or does the LSTM know how to do it by itself?
My forward function looks like this:
    def forward(self, features, caption, h=None, c=None):
        batch_size = caption.size(0)
        caption_size = caption.size(1)

        # If no hidden state was passed in, initialize one and remember
        # not to return it.
        no_hc = False
        if h is None and c is None:
            no_hc = True
            h, c = self.init_hidden(batch_size)

        embeddings = self.embedding(caption)
        output = torch.empty((batch_size, caption_size, self.vocab_size)).to(device)

        for i in range(caption_size):  # go over the words in the sentence
            if i == 0:
                # first time step: the image features
                lstm_input = features.unsqueeze(1)
            else:
                # later steps: embedding of the previous ground-truth word
                lstm_input = embeddings[:, i - 1, :].unsqueeze(1)
            out, (h, c) = self.lstm(lstm_input, (h, c))
            out = self.fc(out)
            output[:, i, :] = out.squeeze(1)

        if no_hc:
            return output
        return output, h, c
(I took inspiration from here.)
The output of forward here is of size batch_size * seq_size * vocab_size, which is good because it can be compared with the original batch_size * seq_size sized caption in the loss function.
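(For example, with nn.CrossEntropyLoss; flattening the batch and time dimensions like this is my own choice, just to illustrate the shape match:)

    import torch.nn as nn

    criterion = nn.CrossEntropyLoss()
    # output: batch_size * seq_size * vocab_size, caption: batch_size * seq_size
    loss = criterion(output.reshape(-1, output.size(-1)),  # (B*T, vocab_size)
                     caption.reshape(-1))                  # (B*T,)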
The question is whether this for loop inside forward that feeds the words one after the other is really necessary, or whether I can somehow feed the entire sentence at once and get the same results.
(I saw some examples that do that, for example this one, but I'm not sure it's really equivalent.)
The answer is: the LSTM knows how to do it on its own; you do not have to manually feed each word one by one. An intuitive way to understand this is that the shape of the batch you send contains seq_length (batch.shape[1]), from which nn.LSTM infers the number of time steps in the sentence. Internally, the words are passed through the LSTM cell one step at a time, generating the hidden state h and cell state c along the way.
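For concreteness, here is a minimal sketch of what the vectorized forward could look like, reusing the module names from your question and assuming self.lstm was created with batch_first=True (which your loop already implies). Note that this matches your loop only because the loop feeds the ground-truth previous word at every step (teacher forcing); at inference time, when each input is the model's own previous prediction, a step-by-step loop is still needed.

    import torch

    # Drop-in sketch for the forward above (a method of the same model class).
    def forward(self, features, caption):
        # Build the whole input sequence at once: the image features as the
        # first time step, then the ground-truth embeddings of all words
        # except the last (teacher forcing), exactly as the loop does.
        embeddings = self.embedding(caption)                    # B * T * E
        lstm_input = torch.cat((features.unsqueeze(1),
                                embeddings[:, :-1, :]), dim=1)  # B * T * E
        h, c = self.init_hidden(caption.size(0))
        out, (h, c) = self.lstm(lstm_input, (h, c))             # B * T * hidden
        output = self.fc(out)                                   # B * T * vocab_size
        return output, h, c

Since this builds the same input sequence as the loop (image features first, then the embeddings of words 0..T-2), the outputs match the loop version step for step, and the LSTM handles the iteration over time internally.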