machine-learning, torch, recurrent-neural-network, pytorch

How do you use PyTorch PackedSequence in code?


Can someone give full working code (not a snippet, but something that runs end to end on a variable-length recurrent neural network) showing how you would use the PackedSequence class in PyTorch?

There do not seem to be any examples of this in the documentation, on GitHub, or anywhere else on the internet.

https://github.com/pytorch/pytorch/releases/tag/v0.1.10


Solution

  • Not the most beautiful piece of code, but this is what I put together for my own use after going through the PyTorch forums and docs. There are certainly better ways to handle the sorting and restoring, but I chose to keep that logic inside the network itself.

    EDIT: See the answer from @tusonggao, which makes the torch utils take care of the sorting.
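
    (For reference, a minimal sketch of that simpler route, assuming PyTorch >= 1.1, where pack_padded_sequence accepts enforce_sorted=False and does the sorting and unsorting internally; the dimensions and tensor names below are made up:)

    import torch
    import torch.nn as nn

    rnn = nn.GRU(input_size=8, hidden_size=16, batch_first=True)

    # A padded batch of shape (batch, max_len, features) plus the true length of each row.
    x = torch.randn(3, 5, 8)
    lengths = torch.tensor([5, 2, 4])

    # enforce_sorted=False makes pack_padded_sequence sort the batch itself,
    # and pad_packed_sequence restores the original order on the way out.
    packed = nn.utils.rnn.pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)
    out, _ = rnn(packed)
    unpacked, unpacked_len = nn.utils.rnn.pad_packed_sequence(out, batch_first=True)  # (3, 5, 16)

    The original version below keeps the sorting and restoring inside the network instead: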

    import numpy as np
    import torch
    import torch.nn as nn
    from torch.autograd import Variable

    class Encoder(nn.Module):
        def __init__(self, vocab_size, embedding_size, embedding_vectors=None, tune_embeddings=True, use_gru=True,
                     hidden_size=128, num_layers=1, bidirectional=True, dropout=0.6):
            super(Encoder, self).__init__()
            self.embed = nn.Embedding(vocab_size, embedding_size, padding_idx=0)
            self.embed.weight.requires_grad = tune_embeddings
            if embedding_vectors is not None:
                assert embedding_vectors.shape[0] == vocab_size and embedding_vectors.shape[1] == embedding_size
                self.embed.weight = nn.Parameter(torch.FloatTensor(embedding_vectors))
            cell = nn.GRU if use_gru else nn.LSTM
            # Note: RNN-internal dropout only applies between layers, i.e. when num_layers > 1.
            self.rnn = cell(input_size=embedding_size, hidden_size=hidden_size, num_layers=num_layers,
                            batch_first=True, bidirectional=bidirectional, dropout=dropout)

        def forward(self, x, x_lengths):
            # pack_padded_sequence needs the batch sorted by decreasing length;
            # sort_order is the permutation that produces that sorted order.
            sorted_seq_lens, sort_order = torch.sort(torch.LongTensor(x_lengths), dim=0, descending=True)
            ex = self.embed(x[sort_order])
            pack = torch.nn.utils.rnn.pack_padded_sequence(ex, sorted_seq_lens.tolist(), batch_first=True)
            out, _ = self.rnn(pack)
            unpacked, unpacked_len = torch.nn.utils.rnn.pad_packed_sequence(out, batch_first=True)
            # Gather the output at the last valid timestep of each (still sorted) sequence.
            indices = Variable(torch.LongTensor(np.array(unpacked_len) - 1).view(-1, 1)
                                                                           .expand(unpacked.size(0), unpacked.size(2))
                                                                           .unsqueeze(1))
            last_encoded_states = unpacked.gather(dim=1, index=indices).squeeze(dim=1)
            # Scatter the rows back to their positions in the original (unsorted) batch.
            scatter_indices = Variable(sort_order.view(-1, 1).expand_as(last_encoded_states))
            encoded_reordered = last_encoded_states.clone().scatter_(dim=0, index=scatter_indices, src=last_encoded_states)
            return encoded_reordered
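
    For completeness, a quick usage sketch; the vocabulary size, token ids, and lengths here are made up, not from the original answer:

    encoder = Encoder(vocab_size=100, embedding_size=50, hidden_size=128)
    x = Variable(torch.LongTensor([[4, 7, 9, 2, 5],
                                   [3, 8, 0, 0, 0],    # 0 is the padding index
                                   [6, 1, 2, 9, 0]]))
    x_lengths = [5, 2, 4]
    encoded = encoder(x, x_lengths)
    print(encoded.size())  # (3, 256): hidden_size * 2 because the RNN is bidirectional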