I'm a beginner in PyTorch. From the LSTM documentation, I learned that I can create a stacked LSTM with 3 layers like this:
layer = torch.nn.LSTM(128, 512, num_layers=3)
Then in the forward function, I can do:
def forward(x, state):
    x, state = layer(x, state)
    return x, (state[0].detach(), state[1].detach())
And I can pass state from batch to batch.
But if I create 3 separate LSTM layers and want to implement the same stacking myself, what is the equivalent?
layer1 = torch.nn.LSTM(128, 512, num_layers=1)
layer2 = torch.nn.LSTM(512, 512, num_layers=1)  # input is layer1's 512-dim output
layer3 = torch.nn.LSTM(512, 512, num_layers=1)
In this case, what should go into the forward function, and what should the returned state look like?
I also tried to look at the source code of PyTorch's LSTM, but its forward function calls into a _VF module, and I cannot find where that is defined.
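If you define state as a list of the 3 layers' states, where each entry is that layer's (h, c) tuple, then:
def forward(x, state):
    x, s0 = layer1(x, state[0])
    x, s1 = layer2(x, state[1])
    x, s2 = layer3(x, state[2])
    # each s is an (h, c) tuple, so detach both tensors rather than the tuple itself
    return x, [tuple(t.detach() for t in s) for s in (s0, s1, s2)]
Note that layer2 and layer3 must be constructed with an input size of 512, since each one consumes the 512-dimensional output of the layer before it.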
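To check that the per-layer version really matches num_layers=3, here is a minimal sketch that copies the weights of a 3-layer nn.LSTM into three single-layer ones and compares the outputs over two consecutive batches. The names ManualStackedLSTM and copy_weights are made up for this illustration, and the sketch assumes the default settings (no dropout, unidirectional, batch_first=False).

```python
import torch
from torch import nn


class ManualStackedLSTM(nn.Module):
    """Hypothetical wrapper: single-layer LSTMs chained like nn.LSTM(num_layers=3)."""

    def __init__(self, input_size=128, hidden_size=512, num_layers=3):
        super().__init__()
        # Layer 0 sees the raw input; later layers see the previous layer's output.
        sizes = [input_size] + [hidden_size] * (num_layers - 1)
        self.layers = nn.ModuleList([nn.LSTM(s, hidden_size) for s in sizes])

    def forward(self, x, state=None):
        # state: list with one (h, c) tuple per layer, or None to start from zeros.
        if state is None:
            state = [None] * len(self.layers)
        new_state = []
        for layer, s in zip(self.layers, state):
            x, (h, c) = layer(x, s)
            # Detach so gradients do not flow back into earlier batches.
            new_state.append((h.detach(), c.detach()))
        return x, new_state


def copy_weights(stacked, manual):
    """Copy layer k of a multi-layer nn.LSTM into the k-th single-layer LSTM."""
    with torch.no_grad():
        for k, layer in enumerate(manual.layers):
            for name in ("weight_ih", "weight_hh", "bias_ih", "bias_hh"):
                src = getattr(stacked, f"{name}_l{k}")
                getattr(layer, f"{name}_l0").copy_(src)


torch.manual_seed(0)
stacked = nn.LSTM(128, 512, num_layers=3)
manual = ManualStackedLSTM(128, 512, 3)
copy_weights(stacked, manual)

x1 = torch.randn(5, 2, 128)  # (seq_len, batch, features)
x2 = torch.randn(5, 2, 128)

# First batch: both versions start from a zero state.
out_s, (h, c) = stacked(x1)
out_m, state = manual(x1)

# Second batch: carry the state over in each representation.
out_s2, _ = stacked(x2, (h.detach(), c.detach()))
out_m2, _ = manual(x2, state)

same_first = torch.allclose(out_s, out_m, atol=1e-5)
same_second = torch.allclose(out_s2, out_m2, atol=1e-5)
print(same_first, same_second)
```

The built-in layer keeps its state as two tensors of shape (num_layers, batch, hidden), while the manual version keeps one (h, c) pair of shape (1, batch, hidden) per layer; slice k of the former corresponds to entry k of the latter.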