Can someone give a full working code example (not a snippet, but something that runs on a variable-length recurrent neural network) showing how to use the PackedSequence method in PyTorch?
There do not seem to be any examples of this in the documentation, github, or the internet.
Not the most beautiful piece of code, but this is what I gathered for my personal use after going through the PyTorch forums and docs. There are certainly better ways to handle the sorting and restoring part, but I chose to do it inside the network itself.
EDIT: See answer from @tusonggao which makes torch utils take care of sorting parts
class Encoder(nn.Module):
    """Encode padded, variable-length token sequences with a GRU/LSTM.

    The batch is sorted by length, packed, run through the RNN, unpacked, and
    the output at each sequence's last valid time step is returned in the
    caller's original batch order.
    """

    def __init__(self, vocab_size, embedding_size, embedding_vectors=None, tune_embeddings=True, use_gru=True,
                 hidden_size=128, num_layers=1, bidrectional=True, dropout=0.6):
        """Build the embedding layer and the recurrent layer.

        Args:
            vocab_size: Number of rows in the embedding table.
            embedding_size: Width of each embedding vector.
            embedding_vectors: Optional pretrained table of shape
                ``(vocab_size, embedding_size)`` used to initialise the
                embedding weights.
            tune_embeddings: Whether the embedding weights receive gradients.
            use_gru: Use ``nn.GRU`` when true, ``nn.LSTM`` otherwise.
            hidden_size: Hidden size of the recurrent layer (per direction).
            num_layers: Number of stacked recurrent layers.
            bidrectional: Whether the RNN is bidirectional. (The misspelled
                name is kept so existing keyword callers keep working.)
            dropout: Dropout between stacked recurrent layers.
        """
        super(Encoder, self).__init__()
        self.embed = nn.Embedding(vocab_size, embedding_size, padding_idx=0)
        if embedding_vectors is not None:
            assert embedding_vectors.shape[0] == vocab_size and embedding_vectors.shape[1] == embedding_size, \
                "embedding_vectors must have shape (vocab_size, embedding_size)"
            # Clone so the parameter never aliases the caller's array/tensor.
            pretrained = torch.as_tensor(embedding_vectors, dtype=torch.float).clone()
            self.embed.weight = nn.Parameter(pretrained)
        # Set the flag AFTER the weight may have been replaced: a fresh
        # nn.Parameter defaults to requires_grad=True, which previously made
        # tune_embeddings=False a no-op when pretrained vectors were given.
        self.embed.weight.requires_grad = tune_embeddings

        cell = nn.GRU if use_gru else nn.LSTM
        # Inter-layer dropout is meaningless for a single layer and only
        # triggers a PyTorch warning, so it is disabled in that case.
        rnn_dropout = dropout if num_layers > 1 else 0.0
        self.rnn = cell(input_size=embedding_size, hidden_size=hidden_size, num_layers=num_layers,
                        batch_first=True, bidirectional=bidrectional, dropout=rnn_dropout)

    def forward(self, x, x_lengths):
        """Return the RNN output at the last valid step of every sequence.

        Args:
            x: Tensor of token ids, shape ``(batch, max_len)``, padded with 0.
            x_lengths: True (non-zero) length of each row of ``x``; a list or
                a 1-D tensor.

        Returns:
            Tensor of shape ``(batch, num_directions * hidden_size)`` in the
            same row order as ``x``.
        """
        # pack_padded_sequence expects lengths on the CPU.
        lengths = torch.as_tensor(x_lengths, dtype=torch.long).cpu()
        sorted_lengths, sort_order = torch.sort(lengths, dim=0, descending=True)
        sort_order = sort_order.to(x.device)

        embedded = self.embed(x[sort_order])
        packed = torch.nn.utils.rnn.pack_padded_sequence(embedded, sorted_lengths.tolist(), batch_first=True)
        packed_out, _ = self.rnn(packed)
        unpacked, unpacked_len = torch.nn.utils.rnn.pad_packed_sequence(packed_out, batch_first=True)

        # Index of the last real time step per sequence, shaped (batch, 1, features)
        # so it can be used with gather along the time dimension.
        last_step = torch.as_tensor(unpacked_len, dtype=torch.long, device=unpacked.device) - 1
        gather_index = last_step.view(-1, 1, 1).expand(unpacked.size(0), 1, unpacked.size(2))
        # NOTE: for a bidirectional RNN the backward half of this vector has
        # only seen the final token at this time step; kept as-is so outputs
        # stay identical to the original implementation.
        last_encoded_states = unpacked.gather(dim=1, index=gather_index).squeeze(dim=1)

        # Undo the length sort: row i of the sorted batch goes back to row sort_order[i].
        scatter_index = sort_order.view(-1, 1).expand_as(last_encoded_states)
        return last_encoded_states.scatter(0, scatter_index, last_encoded_states)