Tags: machine-learning, deep-learning, pytorch, lstm, recurrent-neural-network

Arguments and function call of LSTM in PyTorch


Could anyone please explain the code below to me?

import torch
import torch.nn as nn

input = torch.randn(5, 3, 10)
h0 = torch.randn(2, 3, 20)
c0 = torch.randn(2, 3, 20)

rnn = nn.LSTM(10,20,2)

output, (hn, cn) = rnn(input, (h0, c0))
print(input)

When calling rnn(input, (h0, c0)), we pass h0 and c0 inside parentheses. What is that supposed to mean? If (h0, c0) represents a single value, then what is that value, and what is the third argument passed here? However, in the line rnn = nn.LSTM(10,20,2) we pass the arguments to the LSTM function without extra parentheses. Can anyone explain how this function call works?


Solution

  • The assignment rnn = nn.LSTM(10, 20, 2) instantiates a new nn.Module using the nn.LSTM class. Its first three arguments are input_size (here 10), hidden_size (here 20) and num_layers (here 2).

    On the other hand, rnn(input, (h0, c0)) actually calls the class instance, i.e. it runs __call__, which in turn runs the module's forward method (along with any registered hooks). When called, nn.LSTM takes two arguments: input, shaped (sequence_length, batch_size, input_size), and a single tuple of two tensors (h_0, c_0), both shaped (num_layers, batch_size, hidden_size) in the basic use case of nn.LSTM (unidirectional, with the default batch_first=False). It returns output together with another tuple (h_n, c_n), which is why the result is unpacked as output, (hn, cn).

    Please refer to the PyTorch documentation whenever you use built-in modules: it gives the exact definition of the parameter list (the arguments used to initialize the class instance) as well as the input/output specifications (for when you call the module).


    You might be confused by the notation; here is a small example that could help:

    • tuple as input:

      def fn1(x, p):
          a, b = p # unpack input
          return a*x + b
      
      >>> fn1(2, (3, 1))
      7
      
    • tuple as output:

      def fn2(x):
          return x, (3*x, x**2)  # the output is a tuple holding an int and another tuple
      
      >>> fn2(2)
      (2, (6, 4))
      
      >>> x, (a, b) = fn2(2)  # unpacking mirrors the output's structure, like output, (hn, cn) = rnn(...)
      >>> x, a, b
      (2, 6, 4)
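Here is a minimal sketch of the shapes involved, assuming the default batch_first=False and a unidirectional LSTM. It writes nn.LSTM(10, 20, 2) with explicit keyword arguments and uses zero initial states, so the result can be compared with calling the LSTM without any state, which PyTorch treats as zero-initialized (h_0, c_0). The name x replaces input only to avoid shadowing Python's built-in.

```python
import torch
import torch.nn as nn

# nn.LSTM(10, 20, 2) written with keyword arguments (these are constructor arguments)
rnn = nn.LSTM(input_size=10, hidden_size=20, num_layers=2)

seq_len, batch_size = 5, 3
x = torch.randn(seq_len, batch_size, 10)  # (sequence_length, batch_size, input_size)
h0 = torch.zeros(2, batch_size, 20)       # (num_layers, batch_size, hidden_size)
c0 = torch.zeros(2, batch_size, 20)       # same shape as h0

# Call arguments: the input tensor plus ONE tuple that groups h0 and c0
output, (hn, cn) = rnn(x, (h0, c0))

print(output.shape)  # torch.Size([5, 3, 20]): (sequence_length, batch_size, hidden_size)
print(hn.shape)      # torch.Size([2, 3, 20]): final hidden state of each layer
print(cn.shape)      # torch.Size([2, 3, 20]): final cell state of each layer

# Omitting the tuple makes the LSTM start from zero states, so the outputs match
output_default, _ = rnn(x)
print(torch.allclose(output, output_default))  # True
```

There is no third argument here: rnn(x, (h0, c0)) passes exactly two arguments, and the inner parentheses only build a tuple, the same way fn1(2, (3, 1)) does.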
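To see why the constructor call and the module call look different, here is a small sketch using a hypothetical module, ScaleAndStep, which is not part of PyTorch. Its constructor argument is consumed once by __init__ when the instance is created, while calling the instance goes through nn.Module.__call__, which runs forward with the call arguments. Like nn.LSTM, it takes its state as a single tuple and returns a tuple inside its output.

```python
import torch
import torch.nn as nn

class ScaleAndStep(nn.Module):
    """Hypothetical toy module, used only to mirror nn.LSTM's calling pattern."""

    def __init__(self, factor):
        super().__init__()
        self.factor = factor  # constructor argument, stored on the instance

    def forward(self, x, state):
        h, c = state  # the second call argument is one tuple, unpacked here
        return x * self.factor, (h + 1, c + 1)

m = ScaleAndStep(2)                        # runs __init__ with factor=2
x = torch.ones(3)
state = (torch.zeros(3), torch.zeros(3))
out, (h, c) = m(x, state)                  # nn.Module.__call__ dispatches to forward
print(out)  # tensor([2., 2., 2.])
print(h)    # tensor([1., 1., 1.])
```

Calling m.forward(x, state) directly would give the same result here, but calling the instance is the usual form because __call__ also runs any registered hooks.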