Tags: coding-style, open-source, lstm, implementation, pytorch

Does a clean and extendable LSTM implementation exist in PyTorch?


I would like to create an LSTM class of my own; however, I don't want to rewrite the classic LSTM functionality from scratch again.

Digging into the PyTorch code, I only find a messy implementation that involves at least 3-4 classes with inheritance:

  1. https://github.com/pytorch/pytorch/blob/98c24fae6b6400a7d1e13610b20aa05f86f77070/torch/nn/modules/rnn.py#L323
  2. https://github.com/pytorch/pytorch/blob/98c24fae6b6400a7d1e13610b20aa05f86f77070/torch/nn/modules/rnn.py#L12
  3. https://github.com/pytorch/pytorch/blob/98c24fae6b6400a7d1e13610b20aa05f86f77070/torch/nn/_functions/rnn.py#L297

Does a clean PyTorch implementation of an LSTM exist somewhere? Any links would help.

For example, I know that clean implementations of an LSTM exist in TensorFlow, but I would need to port one to PyTorch.

For a clear example, what I'm searching for is an implementation as clean as this, but in PyTorch:


Solution

  • The best implementation I found is here
    https://github.com/pytorch/benchmark/blob/master/rnns/benchmarks/lstm_variants/lstm.py

    It even implements four different variants of recurrent dropout, which is very useful!
    If you take the dropout parts away, you get:

    import math
    import torch as th
    import torch.nn as nn
    
    class LSTM(nn.Module):
    
        def __init__(self, input_size, hidden_size, bias=True):
            super(LSTM, self).__init__()
            self.input_size = input_size
            self.hidden_size = hidden_size
            self.bias = bias
            self.i2h = nn.Linear(input_size, 4 * hidden_size, bias=bias)
            self.h2h = nn.Linear(hidden_size, 4 * hidden_size, bias=bias)
            self.reset_parameters()
    
        def reset_parameters(self):
            std = 1.0 / math.sqrt(self.hidden_size)
            for w in self.parameters():
                w.data.uniform_(-std, std)
    
        def forward(self, x, hidden):
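            # x is expected to have shape (1, batch, input_size); hidden is a
            # tuple (h, c) with both tensors of shape (1, batch, hidden_size)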
            h, c = hidden
            h = h.view(h.size(1), -1)
            c = c.view(c.size(1), -1)
            x = x.view(x.size(1), -1)
    
            # Linear mappings
            preact = self.i2h(x) + self.h2h(h)
    
            # activations
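            # The pre-activation is split into the three sigmoid gates
            # (input i_t, forget f_t, output o_t) and the tanh cell candidate
            # g_t, giving the standard LSTM update
            #   c_t = f_t * c_{t-1} + i_t * g_t
            #   h_t = o_t * tanh(c_t)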
            gates = preact[:, :3 * self.hidden_size].sigmoid()
            g_t = preact[:, 3 * self.hidden_size:].tanh()
            i_t = gates[:, :self.hidden_size]
            f_t = gates[:, self.hidden_size:2 * self.hidden_size]
            o_t = gates[:, -self.hidden_size:]
    
            c_t = th.mul(c, f_t) + th.mul(i_t, g_t)
    
            h_t = th.mul(o_t, c_t.tanh())
    
            h_t = h_t.view(1, h_t.size(0), -1)
            c_t = c_t.view(1, c_t.size(0), -1)
            return h_t, (h_t, c_t)
    

    PS: The repository contains many more variants of LSTM and other RNNs:
    https://github.com/pytorch/benchmark/tree/master/rnns/benchmarks.
    Check it out, maybe the extension you had in mind is already there!
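
    For a quick sanity check, the stripped-down cell above can be driven for a single step like this (this snippet is not part of the original answer; the sizes are arbitrary and it assumes the imports from the block above):

    batch, input_size, hidden_size = 4, 10, 20
    cell = LSTM(input_size, hidden_size)
    x = th.randn(1, batch, input_size)      # a single time step
    h0 = th.zeros(1, batch, hidden_size)
    c0 = th.zeros(1, batch, hidden_size)
    out, (h1, c1) = cell(x, (h0, c0))
    print(out.shape)                        # torch.Size([1, 4, 20])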

    EDIT:
    As mentioned in the comments, you can wrap the LSTM cell above so that it processes whole input sequences; the methods marked "As before" are identical to the cell above:

    import math
    import torch as th
    import torch.nn as nn
    
    
    class LSTMCell(nn.Module):
    
        def __init__(self, input_size, hidden_size, bias=True):
            # As before
    
        def reset_parameters(self):
            # As before
    
        def forward(self, x, hidden):
    
            if hidden is None:
                hidden = self._init_hidden(x)
    
            # Rest as before
    
        def _init_hidden(self, input_):
            # Zero-initialised hidden and cell states of shape
            # (1, batch, hidden_size), matching what forward() expects
            h = input_.new_zeros(1, input_.size(1), self.hidden_size)
            c = input_.new_zeros(1, input_.size(1), self.hidden_size)
            return h, c
    
    
    class LSTM(nn.Module):
    
        def __init__(self, input_size, hidden_size, bias=True):
            super().__init__()
            self.lstm_cell = LSTMCell(input_size, hidden_size, bias)
    
        def forward(self, input_, hidden=None):
            # input_ has shape (1, time, batch, input_size), so each time
            # step handed to the cell matches what LSTMCell.forward expects

            outputs = []
            for x in th.unbind(input_, dim=1):
                output, hidden = self.lstm_cell(x, hidden)
                outputs.append(output.clone())

            return th.stack(outputs, dim=1)
    

    I haven't tested the code since I'm working with a convLSTM implementation. Please let me know if something is wrong.
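
    If it helps, here is a rough sketch (again, not from the original answer) of how the wrapper could be exercised on random data once the elided methods are filled in from the cell above; it assumes the (1, time, batch, input_size) layout noted in the code:

    seq_len, batch, input_size, hidden_size = 5, 4, 10, 20
    rnn = LSTM(input_size, hidden_size)
    x = th.randn(1, seq_len, batch, input_size)
    out = rnn(x)            # hidden state is zero-initialised inside LSTMCell
    print(out.shape)        # torch.Size([1, 5, 4, 20])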

    UPDATE: Fixed links.