python machine-learning pytorch lstm

Longer LSTM Prediction


I am using an LSTM that takes a sequence of 5 timesteps as input and predicts another 5. I want to know how to predict more than 5 timesteps. I assume it's got something to do with the hidden_dim, but I can't figure it out.

Here is my code

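import torch
import torch.nn as nn

# Run on the GPU if one is available, otherwise on the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class LSTM(nn.Module):
    def __init__(self, seq_len=5, n_features=256, n_hidden=256, n_layers=1, output_size=1):
        super().__init__()
        self.n_features = n_features
        self.seq_len = seq_len
        self.n_hidden = n_hidden
        self.n_layers = n_layers
        # output_size is accepted but not used anywhere yet
        self.l_lstm = nn.LSTM(input_size=self.n_features, hidden_size=self.n_hidden,
                              num_layers=self.n_layers, batch_first=True)
        # (h, c) state; None makes nn.LSTM start from zeros if init_hidden was not called
        self.hidden = None

    def init_hidden(self, batch_size):
        # Reset hidden and cell states to zeros, shape (n_layers, batch, n_hidden)
        hidden_state = torch.zeros(self.n_layers, batch_size, self.n_hidden).to(device)
        cell_state = torch.zeros(self.n_layers, batch_size, self.n_hidden).to(device)
        self.hidden = (hidden_state, cell_state)

    def forward(self, x):
        # x: (batch, seq_len, n_features) because batch_first=True
        lstm_out, self.hidden = self.l_lstm(x, self.hidden)
        # lstm_out: (batch, seq_len, n_hidden), one output per input timestep
        return lstm_out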

If anyone knows how to extend the prediction range or could suggest a better way of writing this LSTM, I would really appreciate it.
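A quick way to check whether hidden_dim controls the prediction length is to run nn.LSTM directly with the question's sizes. In this sketch the helper `output_shape` and the batch size of 4 are made up purely for illustration. It shows that the number of output steps always equals the number of input steps, while `hidden_size` only changes the width of each step's output vector.

```python
import torch
import torch.nn as nn


def output_shape(seq_len, hidden_size=256, batch_size=4, n_features=256):
    """Run a one-layer LSTM over random input and return the shape of its per-step output."""
    layer = nn.LSTM(input_size=n_features, hidden_size=hidden_size,
                    num_layers=1, batch_first=True)
    x = torch.randn(batch_size, seq_len, n_features)  # (batch, seq_len, n_features)
    out, (h, c) = layer(x)                            # initial state defaults to zeros
    return tuple(out.shape)                           # (batch, seq_len, hidden_size)


print(output_shape(5))                  # (4, 5, 256)
print(output_shape(10))                 # (4, 10, 256): more input steps, more output steps
print(output_shape(5, hidden_size=64))  # (4, 5, 64): hidden size only changes the last dim
```

So with this setup, getting more output steps means feeding more steps through the LSTM, not enlarging the hidden state.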


Solution

  • Right now you are running your LSTM forward for 5 timesteps and returning the hidden state it produces at each timestep. This approach is common when you need exactly one output for every input, as in sequence labeling problems (e.g. tagging each word in a sentence with its part of speech).

    If you want to encode a variable-length sequence and then decode a sequence of arbitrary, possibly different length (e.g. for machine translation), you need to look into sequence-to-sequence (seq2seq) modeling more generally. This is somewhat more involved and uses two LSTMs: one to encode the input sequence and another to decode the output sequence (see the EncoderRNN and DecoderRNN implementations in the PyTorch seq2seq translation tutorial).

    The basic idea is to take, for example, the final state of the encoder LSTM after it has consumed the input sentence and use that state to initialize a separate LSTM decoder, from which you sample autoregressively. In other words, you generate a new token, feed that token back into the decoder, and repeat, either for an arbitrary number of steps that you specify or until the decoder samples an "end of sentence" token (if you have trained it to predict the end of sequences).
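The encoder/decoder idea can be sketched for the question's continuous 256-feature data as follows. This is a minimal illustration under several assumptions, not the tutorial's code: the class name `Seq2Seq`, the linear `head` that maps decoder hidden states back to feature vectors, and seeding the decoder with the last observed input step are all hypothetical choices. Because the data here are vectors rather than tokens, "sample a token and feed it back" becomes "predict a vector and feed it back", and decoding stops after a fixed `n_steps` instead of at an end-of-sequence token.

```python
import torch
import torch.nn as nn


class Seq2Seq(nn.Module):
    """Hypothetical minimal encoder-decoder for continuous-valued sequences."""

    def __init__(self, n_features=256, n_hidden=256, n_layers=1):
        super().__init__()
        # Encoder LSTM reads the observed input sequence
        self.encoder = nn.LSTM(n_features, n_hidden, n_layers, batch_first=True)
        # Separate decoder LSTM, initialized from the encoder's final (h, c) state
        self.decoder = nn.LSTM(n_features, n_hidden, n_layers, batch_first=True)
        # Assumed linear head: maps each decoder hidden state to a feature vector
        self.head = nn.Linear(n_hidden, n_features)

    def forward(self, x, n_steps):
        # x: (batch, seq_len, n_features); n_steps may differ from seq_len
        _, state = self.encoder(x)      # final (h, c) after consuming the input
        step_input = x[:, -1:, :]       # assumption: seed decoder with last observed step
        outputs = []
        for _ in range(n_steps):
            out, state = self.decoder(step_input, state)  # advance one timestep
            pred = self.head(out)       # (batch, 1, n_features)
            outputs.append(pred)
            step_input = pred           # autoregressive: feed the prediction back in
        return torch.cat(outputs, dim=1)  # (batch, n_steps, n_features)


model = Seq2Seq()
x = torch.randn(4, 5, 256)              # 4 sequences of 5 observed timesteps
with torch.no_grad():
    y = model(x, n_steps=12)            # predict 12 future timesteps
print(y.shape)                          # torch.Size([4, 12, 256])
```

The prediction length is now just the `n_steps` argument, independent of both the input length and the hidden size. Seeding with the last input step and using a linear head are simple defaults; other initializations and output layers are equally possible.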