I am trying to put a packed and padded sequence through a GRU and retrieve the output of the last item of each sequence. By that I don't mean the item at index -1, but the actual last, non-padded item. We know the lengths of the sequences in advance, so it should be as easy as extracting the item at index length-1 for each sequence.
I tried the following:
import torch
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
# Data: 4 sequences, padded to 6 time steps, 3 features each
input = torch.Tensor([[[0., 0., 0.],
                       [1., 0., 1.],
                       [1., 1., 0.],
                       [1., 0., 1.],
                       [1., 0., 1.],
                       [1., 1., 0.]],
                      [[1., 1., 0.],
                       [0., 1., 0.],
                       [0., 0., 0.],
                       [0., 1., 0.],
                       [0., 0., 0.],
                       [0., 0., 0.]],
                      [[0., 0., 0.],
                       [1., 0., 0.],
                       [1., 1., 1.],
                       [0., 0., 0.],
                       [0., 0., 0.],
                       [0., 0., 0.]],
                      [[1., 1., 0.],
                       [0., 0., 0.],
                       [0., 0., 0.],
                       [0., 0., 0.],
                       [0., 0., 0.],
                       [0., 0., 0.]]])
lengths = [6, 4, 3, 1]
p = pack_padded_sequence(input, lengths, batch_first=True)
# Forward
gru = torch.nn.GRU(3, 12, batch_first=True)
packed_output, gru_h = gru(p)
# Unpack
output, input_sizes = pad_packed_sequence(packed_output, batch_first=True)
last_seq_idxs = torch.LongTensor([x-1 for x in input_sizes])
last_seq_items = torch.index_select(output, 1, last_seq_idxs)
print(last_seq_items.size())
# torch.Size([4, 4, 12])
But the shape is not what I expect. I had expected to get 4x12, i.e. last item of each individual sequence x hidden.
I could loop through the whole thing and build a new tensor containing the items I need, but I was hoping for a built-in approach that takes advantage of some smart math. I fear that manually looping and building will result in very poor performance.
Instead of the last two operations, last_seq_idxs and last_seq_items, you could just do last_seq_items = output[torch.arange(4), input_sizes-1].
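A minimal sketch of that indexing on a self-contained batch, assuming a single-layer, unidirectional GRU. The random data and the names `batch` and `batch_idx` are illustrative and not taken from the question. Under that assumption, the selected rows should also match the final hidden state `gru_h[0]` that the GRU returns for packed input.

```python
import torch
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

torch.manual_seed(0)

# Illustrative batch: 4 sequences, padded to 6 steps, 3 features each.
batch = torch.randn(4, 6, 3)
lengths = torch.tensor([6, 4, 3, 1])  # true lengths, sorted descending

gru = torch.nn.GRU(3, 12, batch_first=True)
packed = pack_padded_sequence(batch, lengths, batch_first=True)
packed_output, gru_h = gru(packed)
output, input_sizes = pad_packed_sequence(packed_output, batch_first=True)

# Pair each batch index with its own last valid time step.
batch_idx = torch.arange(output.size(0))
last_seq_items = output[batch_idx, input_sizes - 1]

print(last_seq_items.size())  # torch.Size([4, 12])
```

Using `output.size(0)` instead of a hard-coded 4 keeps the line valid if the batch size changes.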
I don't think index_select is doing the right thing here. It selects the given indices along dimension 1 for every sequence in the batch, not one index per sequence, which is why your output size is [4, 4, 12].
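The difference can be seen on a small tensor. This is only a sketch: `x`, `last`, and the other names are made up for illustration. It also shows a `gather`-based formulation as one possible alternative to the paired indexing.

```python
import torch

# Illustrative tensor: batch of 4, 6 time steps, 2 features.
x = torch.arange(4 * 6 * 2).reshape(4, 6, 2)
last = torch.tensor([5, 3, 2, 0])  # last valid step per sequence

# index_select takes all four listed steps from every sequence.
all_pairs = torch.index_select(x, 1, last)
print(all_pairs.shape)  # torch.Size([4, 4, 2])

# Advanced indexing pairs batch element i with step last[i] only.
paired = x[torch.arange(4), last]
print(paired.shape)  # torch.Size([4, 2])

# The wanted rows are the "diagonal" of the index_select result.
diag = all_pairs[torch.arange(4), torch.arange(4)]

# Alternative: gather along the time dimension, one index per sequence.
idx = last.view(-1, 1, 1).expand(-1, 1, x.size(2))
gathered = x.gather(1, idx).squeeze(1)
```

Both `diag` and `gathered` hold the same rows as `paired`; the paired indexing is simply the shortest way to write it.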