Could anyone please explain the code below to me?
import torch
import torch.nn as nn
input = torch.randn(5, 3, 10)
h0 = torch.randn(2, 3, 20)
c0 = torch.randn(2, 3, 20)
rnn = nn.LSTM(10,20,2)
output, (hn, cn) = rnn(input, (h0, c0))
print(input)
When calling rnn(input, (h0, c0)), we pass h0 and c0 inside parentheses. What is that supposed to mean? If (h0, c0) represents a single value, what is that value, and what is the third argument passed here?
However, in the line rnn = nn.LSTM(10, 20, 2), the arguments are passed to LSTM without any extra parentheses.
Can anyone explain how these calls work?
The assignment rnn = nn.LSTM(10, 20, 2) instantiates a new nn.Module using the nn.LSTM class. Its first three arguments are input_size (here 10), hidden_size (here 20) and num_layers (here 2).
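As a quick check, the three positional arguments can also be written as keywords, and the resulting module exposes them as attributes. A minimal sketch, assuming a standard PyTorch install:

```python
import torch.nn as nn

# Positional form, as in the question: input_size, hidden_size, num_layers
rnn = nn.LSTM(10, 20, 2)
# Equivalent keyword form, which makes the meaning of each number explicit
rnn_kw = nn.LSTM(input_size=10, hidden_size=20, num_layers=2)

# The constructor arguments are stored on the module instance
print(rnn.input_size, rnn.hidden_size, rnn.num_layers)
print(isinstance(rnn, nn.Module))
```

This is an ordinary constructor call: three plain arguments, no tuple involved.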
On the other hand, rnn(input, (h0, c0)) actually calls the class instance, i.e. it runs __call__, which is roughly equivalent to calling the forward function of that module. The __call__ method of nn.LSTM takes two arguments: input, shaped (sequence_length, batch_size, input_size), and a tuple of two tensors (h_0, c_0), both shaped (num_layers, batch_size, hidden_size) in the basic use case of nn.LSTM. So there is no third argument: (h0, c0) is one argument, a tuple holding the initial hidden state and the initial cell state.
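The sketch below runs the same call as the question with the shapes spelled out, and checks that calling the instance gives the same result as calling forward directly. It assumes the default nn.LSTM settings (unidirectional, batch_first=False); the variable name x is used instead of input only to avoid shadowing the Python builtin.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
rnn = nn.LSTM(10, 20, 2)            # input_size=10, hidden_size=20, num_layers=2

seq_len, batch = 5, 3
x = torch.randn(seq_len, batch, 10)  # (sequence_length, batch_size, input_size)
h0 = torch.randn(2, batch, 20)       # (num_layers, batch_size, hidden_size)
c0 = torch.randn(2, batch, 20)       # (num_layers, batch_size, hidden_size)

state = (h0, c0)                     # ONE object: a tuple of two tensors
output, (hn, cn) = rnn(x, state)     # TWO arguments: x and the state tuple

print(output.shape)                  # top layer's hidden state at every time step
print(hn.shape, cn.shape)            # final hidden and cell state of every layer

# Calling the instance goes through __call__, which dispatches to forward
out2, _ = rnn.forward(x, state)
print(torch.allclose(output, out2))

# For a unidirectional LSTM, the last time step of output is the top layer of hn
print(torch.allclose(output[-1], hn[-1]))
```

In everyday code you call rnn(...) rather than rnn.forward(...), because __call__ also runs any registered hooks.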
Please refer to the PyTorch documentation whenever you use built-in modules: there you will find the exact definition of the parameter list (the arguments used to initialize the class instance) as well as the input/output specifications (for when you run inference with that module).
You might be confused by the notation; here is a small example that could help:
Tuple as input:
def fn1(x, p):
    a, b = p  # unpack the tuple argument
    return a*x + b

>>> fn1(2, (3, 1))
7
Tuple as output:
def fn2(x):
    return x, (3*x, x**2)  # the output is a tuple of an int and a tuple

>>> fn2(2)
(2, (6, 4))
>>> x, (a, b) = fn2(2)  # nested unpacking
>>> x, a, b
(2, 6, 4)
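Putting both halves together: the LSTM call follows the same pattern as fn1 (a tuple going in) and fn2 (a nested tuple coming out). Below is a toy, hypothetical function toy_step (not part of PyTorch, and not doing any real LSTM math) that only mimics that call signature, plus what happens if you pass three separate arguments instead of a tuple:

```python
def toy_step(x, state):
    """Hypothetical stand-in mimicking the LSTM call pattern: (input, (h, c))."""
    h, c = state                   # unpack the single tuple argument
    new_h = h + x                  # made-up update, for illustration only
    new_c = c + new_h
    return new_h, (new_h, new_c)   # output, (h_n, c_n)

out, (hn, cn) = toy_step(1, (2, 3))   # same shape of call as rnn(input, (h0, c0))
print(out, hn, cn)

raised = False
try:
    toy_step(1, 2, 3)              # three separate arguments: not what it expects
except TypeError as err:
    raised = True
    print("TypeError:", err)
```

The parentheses around (2, 3) are what turn two values into a single argument.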