
What is the equivalent of PyTorch LSTM num_layers?


I'm a beginner in PyTorch. From the LSTM documentation, I learned that I can create a stacked LSTM with 3 layers like this:

layer = torch.nn.LSTM(128, 512, num_layers=3)

Then in the forward function, I can do:

def forward(x, state):
    x, state = layer(x, state)
    return x, (state[0].detach(), state[1].detach())
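For reference, here is a minimal sketch of that pattern run over several batches. The sizes (sequence length 10, batch 4) and the random tensors are placeholders for real data. It shows that the returned state is a single (h, c) pair whose tensors are stacked along the first dimension, one slice per layer.

```python
import torch

torch.manual_seed(0)

# Stacked LSTM: input size 128, hidden size 512, 3 layers.
# Default layout is (seq_len, batch, features) because batch_first=False.
layer = torch.nn.LSTM(128, 512, num_layers=3)

def forward(x, state):
    x, state = layer(x, state)
    # Detach so gradients do not flow back into earlier batches (truncated BPTT)
    return x, (state[0].detach(), state[1].detach())

state = None  # None makes nn.LSTM start from zero hidden and cell states
for _ in range(3):  # placeholder loop over three random "batches"
    x = torch.randn(10, 4, 128)  # (seq_len=10, batch=4, input_size=128)
    out, state = forward(x, state)

h_n, c_n = state
print(out.shape)  # torch.Size([10, 4, 512])
print(h_n.shape)  # torch.Size([3, 4, 512]): one slice per layer
```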

This lets me pass the state from batch to batch.
But if I create 3 separate single-layer LSTMs instead, what is the equivalent code if I want to implement the same stacked layers myself?

layer1 = torch.nn.LSTM(128, 512, num_layers=1)
# layers 2 and 3 receive the 512-dim output of the previous layer
layer2 = torch.nn.LSTM(512, 512, num_layers=1)
layer3 = torch.nn.LSTM(512, 512, num_layers=1)

In this case, what should the forward function contain, and how do I get the returned state?
I also tried to read the PyTorch LSTM source code, but its forward function calls a _VF module whose definition I cannot find.
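The _VF call forwards to PyTorch's compiled C++ operators, so there is no readable Python behind it. As a rough substitute, the sketch below recomputes the same thing in plain PyTorch, assuming the gate order documented for nn.LSTM (input, forget, cell, output) and its parameter names weight_ih_l{k}, weight_hh_l{k}, bias_ih_l{k}, bias_hh_l{k}. The helpers manual_lstm_layer and manual_stacked_lstm are hypothetical names for illustration only; the sketch covers just the unidirectional, batch_first=False, dropout=0 case. It shows that num_layers=3 runs layer k over the whole sequence and feeds that output sequence into layer k+1.

```python
import torch

def manual_lstm_layer(lstm, k, x, h, c):
    """Run layer k of `lstm` step by step over x of shape (seq_len, batch, features)."""
    W_ih = getattr(lstm, f"weight_ih_l{k}")
    W_hh = getattr(lstm, f"weight_hh_l{k}")
    b = getattr(lstm, f"bias_ih_l{k}") + getattr(lstm, f"bias_hh_l{k}")
    outputs = []
    for x_t in x:  # iterate over time steps
        gates = x_t @ W_ih.T + h @ W_hh.T + b
        i, f, g, o = gates.chunk(4, dim=1)  # gate order: input, forget, cell, output
        c = torch.sigmoid(f) * c + torch.sigmoid(i) * torch.tanh(g)
        h = torch.sigmoid(o) * torch.tanh(c)
        outputs.append(h)
    return torch.stack(outputs), h, c

def manual_stacked_lstm(lstm, x, h0, c0):
    """Feed the full output sequence of each layer into the next layer."""
    hs, cs = [], []
    for k in range(lstm.num_layers):
        x, h, c = manual_lstm_layer(lstm, k, x, h0[k], c0[k])
        hs.append(h)
        cs.append(c)
    # Stack final states per layer, matching nn.LSTM's (num_layers, batch, hidden)
    return x, (torch.stack(hs), torch.stack(cs))

torch.manual_seed(0)
lstm = torch.nn.LSTM(8, 16, num_layers=3)  # small sizes keep the check fast
x = torch.randn(5, 2, 8)                   # (seq_len, batch, input_size)
h0 = torch.zeros(3, 2, 16)
c0 = torch.zeros(3, 2, 16)

with torch.no_grad():
    ref_out, (ref_h, ref_c) = lstm(x, (h0, c0))
    out, (h, c) = manual_stacked_lstm(lstm, x, h0, c0)

print(torch.allclose(out, ref_out, atol=1e-5))  # True
```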


Solution

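  • If you define state as a list holding each of the 3 layers' (h, c) states, then:

    def forward(x, state):
        # state: list of three (h, c) tuples, one per layer
        # (pass [None, None, None] for the first batch to start from zeros)
        x, s0 = layer1(x, state[0])
        x, s1 = layer2(x, state[1])
        x, s2 = layer3(x, state[2])
        # each s_i is an (h, c) tuple, so detach its two tensors separately
        return x, [(h.detach(), c.detach()) for h, c in (s0, s1, s2)]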
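To check that the three single-layer LSTMs really behave like num_layers=3, one option is to copy layer k's parameters from a stacked LSTM into the k-th single-layer LSTM and compare outputs over two consecutive batches. This is a sketch under assumptions: dropout=0 (the default), batch_first=False, random placeholder data, and the name stacked is hypothetical. It also shows converting the per-layer state list into the stacked (3, batch, 512) layout.

```python
import torch

torch.manual_seed(0)
stacked = torch.nn.LSTM(128, 512, num_layers=3)  # hypothetical reference model

layer1 = torch.nn.LSTM(128, 512, num_layers=1)
layer2 = torch.nn.LSTM(512, 512, num_layers=1)   # input is layer1's 512-dim output
layer3 = torch.nn.LSTM(512, 512, num_layers=1)

# Copy stacked layer k's parameters (suffix _l{k}) into single LSTM k (suffix _l0)
with torch.no_grad():
    for k, single in enumerate((layer1, layer2, layer3)):
        for name in ("weight_ih", "weight_hh", "bias_ih", "bias_hh"):
            getattr(single, f"{name}_l0").copy_(getattr(stacked, f"{name}_l{k}"))

def forward(x, state):
    x, s0 = layer1(x, state[0])
    x, s1 = layer2(x, state[1])
    x, s2 = layer3(x, state[2])
    return x, [(h.detach(), c.detach()) for h, c in (s0, s1, s2)]

# First batch: both start from zero states
x1 = torch.randn(10, 4, 128)
ref_out, (ref_h, ref_c) = stacked(x1)
out, state = forward(x1, [None, None, None])

# Per-layer list -> stacked layout: each h is (1, 4, 512), concatenated to (3, 4, 512)
h = torch.cat([s[0] for s in state])
c = torch.cat([s[1] for s in state])

# Second batch: carry the states over in each model's own format
x2 = torch.randn(10, 4, 128)
ref_out2, _ = stacked(x2, (ref_h, ref_c))
out2, _ = forward(x2, state)

print(torch.allclose(out2, ref_out2, atol=1e-5))
```

If the stacked LSTM is created with dropout > 0, nn.LSTM applies dropout to the outputs of every layer except the last, so a faithful manual version would also need a torch.nn.Dropout between layer1 and layer2 and between layer2 and layer3. To go from the stacked layout back to the list, slice with h[k:k+1] and c[k:k+1] so each single-layer LSTM still receives a (1, batch, hidden) tensor.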