pytorch lstm

Need clear concept of the dimensions of output and hidden from LSTM layers


I know that the output holds the hidden states of the last layer for every time step, and that hidden holds the hidden states of every layer at the last time step. In my setup each document has 850 tokens, and each token is embedded into 100 dimensions. I used a 2-layer LSTM with a 100-dimensional hidden state.

I thought it would take one token per time step and produce a 100-dimensional hidden state. So for one document of 850 tokens I expected output = [1, 850, 100], hidden = [1, 2, 100] and cell = [1, 2, 100]. But hidden and cell actually come out as [2, 850, 100].

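import torch
import torch.nn as nn

# tok2indx, dropout and device are assumed to be defined earlier in the script
input_dim = len(tok2indx)   # size of the vocabulary
emb_dim = 100               # embedding dimension of each token
hid_dim = 100               # dimension of the hidden state produced at each time step
n_layers = 2                # number of stacked LSTM layers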
class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.hid_dim = hid_dim
        self.n_layers = n_layers
        
        self.embedding = nn.Embedding(input_dim, emb_dim)     
        self.rnn = nn.LSTM(emb_dim, hid_dim, n_layers, dropout = dropout, device=device)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, X):
        
        embedded = self.embedding(X).to(device)
        outputs, (hidden, cell) = self.rnn(embedded)
              
        return outputs, hidden, cell

If the encoder is passed a single document:

enc = Encoder()
encd = enc.forward(train_x[:1])
print(encd[0].shape, encd[1].shape, encd[2].shape)

Output:

torch.Size([1, 850, 100]) torch.Size([2, 850, 100]) torch.Size([2, 850, 100])

With ten documents:

encd = enc.forward(train_x[:10])
print(encd[0].shape, encd[1].shape, encd[2].shape)

Output:

torch.Size([10, 850, 100]) torch.Size([2, 850, 100]) torch.Size([2, 850, 100])

Solution

  • What's tripping you up is the input format of the LSTM. By default, nn.LSTM expects input of shape (L, N, H): sequence length, batch, features. In your code you are passing (N, L, H): batch, sequence, features. The layer therefore reads your 850 tokens as a batch of 850 sequences and your number of documents as the sequence length. To fix this, pass batch_first=True to nn.LSTM; the input and output will then be laid out as you expect.

    But there is a catch here too. batch_first only affects the input and the output (the first of the returned values), which will be (N, L, H). For a unidirectional LSTM, hidden and cell (the second and third returned values) always have the shape (num_layers, N, H), regardless of batch_first.

    The second thing to note is that the first dimension of hidden and cell equals the number of layers, 2 in your example, because each layer keeps its own final hidden and cell state. That is why you see [2, 850, 100] rather than [1, 850, 100]. With batch_first=True and a single document you would get [2, 1, 100], not the [1, 2, 100] you expected.
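    To make the shapes concrete, here is a minimal sketch that compares the default layout with batch_first=True. It assumes a random tensor x as a stand-in for the embedded documents (the real code would use the output of nn.Embedding), and it leaves out dropout and device handling, so it is an illustration rather than a drop-in replacement for the Encoder above.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Sizes taken from the question; x is a hypothetical stand-in for the
# embedded batch, laid out as (batch, sequence, features).
batch, seq_len, emb_dim, hid_dim, n_layers = 10, 850, 100, 100, 2
x = torch.randn(batch, seq_len, emb_dim)

with torch.no_grad():
    # Default layout: dim 0 is read as the sequence, dim 1 as the batch,
    # so this sees 850 "documents" of 10 "tokens" each.
    lstm_default = nn.LSTM(emb_dim, hid_dim, n_layers)
    out_d, (h_d, c_d) = lstm_default(x)
    print(out_d.shape, h_d.shape, c_d.shape)

    # batch_first=True: dim 0 is the batch, dim 1 is the sequence.
    lstm_bf = nn.LSTM(emb_dim, hid_dim, n_layers, batch_first=True)
    out_b, (h_b, c_b) = lstm_bf(x)
    print(out_b.shape, h_b.shape, c_b.shape)

# hidden[-1] is the last layer's state at the final time step, which is
# the same tensor as the last time step of the output.
same = torch.allclose(out_b[:, -1, :], h_b[-1], atol=1e-6)
print(same)

# hidden stays (num_layers, batch, hidden); transpose it if a
# batch-first view is needed downstream.
h_batch_first = h_b.transpose(0, 1)
print(h_batch_first.shape)
```

    The first print reproduces the shapes from the question (hidden is [2, 850, 100]), while the second gives hidden and cell of [2, 10, 100]. The transpose at the end is only one way to get a per-document view of the final states; whether you need it depends on what consumes the encoder output.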