Tags: neural-network, pytorch, lstm, padding, recurrent-neural-network

How to manage the hidden state dims when using pad_sequence?


I am trying to build a text generation model with a PyTorch LSTM. For every batch, I use pad_sequence so that each batch is padded only to the length of its longest sequence, which means the batch dimensions vary (batch_size * seq_len). I also apply pack_padded_sequence so that only the non-padding (non-zero) tokens are fed to the LSTM. However, feeding the variable-size batch to the LSTM throws the following error: Expected hidden[0] size (1, 8, 16), got (1, 16, 16). Here the batch size is 16 and every sequence is padded to 8 tokens, but the hidden state is 16 * 16.

I have tried creating the hidden state in the forward function, but that did not work well. How can I create the hidden state so that it accepts variable-size batches and is not lost over the whole epoch?

import torch
import torch.nn as nn

# flags, db, weight_matrix and create_emb_layer are defined elsewhere in the project (not shown).
class RNNModule(nn.Module):
    def __init__(self, n_vocab, seq_size, embedding_size, lstm_size):
        super(RNNModule, self).__init__()
        self.seq_size = seq_size
        self.lstm_size = lstm_size
        self.embedding, num_embeddings, embedding_dim = create_emb_layer(weight_matrix, False)
        self.lstm = nn.LSTM(embedding_size,
                            lstm_size,
                            num_layers=flags.n_layers,
                            batch_first=True)
        self.dense = nn.Linear(lstm_size, n_vocab)

    def forward(self, x, length, prev_state):
        embed = self.embedding(x)  # [batch_size, seq_len, embedding_size]
        # db.pack_src presumably wraps pack_padded_sequence
        packed_input = db.pack_src(embed, length)
        packed_output, state = self.lstm(packed_input, prev_state)
        # db.pad_pack presumably wraps pad_packed_sequence
        padded, _ = db.pad_pack(packed_output)
        logits = self.dense(padded)
        return logits, state

    def zero_state(self, batch_size=flags.batch_size):
        # Initial (h_0, c_0), each of shape [n_layers, batch_size, lstm_size]
        return (torch.zeros(flags.n_layers, batch_size, self.lstm_size),
                torch.zeros(flags.n_layers, batch_size, self.lstm_size))


input: tensor([[  19,    9,    4,    3,   68,    8,    6,    2],
    [  19,    9,    4,    3,    7,    8,    6,    2],
    [   3,   12,   17,   10,    6,   40,    2,    0],
    [   4,    3,  109,    7,    6,    2,    0,    0],
    [ 188,    6,    7,   18,    3,    2,    0,    0],
    [   4,    3,   12,    6,    7,    2,    0,    0],
    [   6,    7,    3,   13,    2,    0,    0,    0],
    [   3,   28,   17,   69,    2,    0,    0,    0],
    [   6,    3,   12,   11,    2,    0,    0,    0],
    [   3,   13,    6,    7,    2,    0,    0,    0],
    [   3,    6,    7,   13,    2,    0,    0,    0],
    [   6,    3,   23,    7,    2,    0,    0,    0],
    [   3,   28,   10,    2,    0,    0,    0,    0],
    [   6,    3,   23,    2,    0,    0,    0,    0],
    [   3,    6,   37,    2,    0,    0,    0,    0],
    [1218,    2,    0,    0,    0,    0,    0,    0]])

Zero tokens are padding. Embedding size: 64, LSTM size: 16, batch size: 16.


Solution

  • The hidden state you create has the correct size, but your input does not. When you pack the input with nn.utils.rnn.pack_padded_sequence, you use batch_first=False, but the data you pass to the packing has size [batch_size, seq_len, embedding_size], so batch_size is the first dimension. The packing therefore reads the second dimension (8) as the batch size, which is why the LSTM expects a hidden state of size (1, 8, 16). For the LSTM you already use batch_first=True, which is appropriate for your data.

    You only need to pack the input correctly by setting batch_first=True there as well, so that it matches the layout of your data.

    rnn_utils.pack_padded_sequence(embed, length, batch_first=True)
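
To see the fix end to end, here is a minimal, self-contained sketch using the sizes from the question (batch size 16, longest sequence 8, embedding size 64, LSTM size 16). Some parts are assumptions made for illustration and are not from the original code: the random `embed` tensor stands in for the output of the embedding layer, `lengths` is read off the example batch by counting non-zero tokens per row, and a single LSTM layer is used because the error message shows a leading dimension of 1.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

torch.manual_seed(0)

batch_size, seq_len, embedding_size, lstm_size = 16, 8, 64, 16

# Stand-in for the embedded batch (assumption): [batch_size, seq_len, embedding_size]
embed = torch.randn(batch_size, seq_len, embedding_size)

# Non-padding token counts per row of the example batch, sorted in descending order
lengths = torch.tensor([8, 8, 7, 6, 6, 6, 5, 5, 5, 5, 5, 5, 4, 4, 4, 2])

# batch_first=True tells the packing that dimension 0 is the batch dimension
packed = pack_padded_sequence(embed, lengths, batch_first=True)

lstm = nn.LSTM(embedding_size, lstm_size, num_layers=1, batch_first=True)

# The hidden state is always [n_layers, batch_size, lstm_size];
# batch_first does not change its layout, and it does not depend on seq_len
h0 = torch.zeros(1, batch_size, lstm_size)
c0 = torch.zeros(1, batch_size, lstm_size)

packed_output, (h_n, c_n) = lstm(packed, (h0, c0))

# Unpack back to a padded tensor with the batch dimension first
padded, out_lengths = pad_packed_sequence(packed_output, batch_first=True)

print(padded.shape)  # torch.Size([16, 8, 16])
print(h_n.shape)     # torch.Size([1, 16, 16])
```

Since the hidden state shape depends only on the number of layers, the batch size and the LSTM size, a state of size (1, 16, 16) fits every batch of 16 sequences regardless of how long the padded sequences in that batch are.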