
How to use pack_padded_sequence with multiple variable-length input with the same label in pytorch


I have a model that takes three variable-length inputs sharing the same label. Is there a way to use pack_padded_sequence here? If so, how should I sort my sequences?

For example,

a = (([0,1,2], [3,4], [5,6,7,8]), 1) # three inputs of lengths 3, 2, 4; label is 1
b = (([0,1], [2], [6,7,8,9,10]), 1)

For both a and b, each of the three inputs is fed into its own LSTM (three separate LSTMs), and the three results are merged to predict the target.


Solution

  • Let's do it step by step.

    Input Data Processing

    import numpy as np
    import torch
    
    a = (([0,1,2], [3,4], [5,6,7,8]), 1)
    seqs, label = a
    
    # store the length of each sequence in an array
    len_a = np.array([len(s) for s in seqs])
    # zero-pad every sequence to the longest length
    variable_a = np.zeros((len(len_a), np.amax(len_a)), dtype=np.int64)
    for i, s in enumerate(seqs):
        variable_a[i, 0:len(s)] = s
    
    # the embedding table must cover the largest token id
    vocab_size = int(variable_a.max()) + 1
    variable_a = torch.from_numpy(variable_a)  # LongTensor of token ids
    print(variable_a)
    

    It prints:

    tensor([[0, 1, 2, 0],
            [3, 4, 0, 0],
            [5, 6, 7, 8]])
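
    The same zero-padding can also be done without the NumPy loop using torch.nn.utils.rnn.pad_sequence. This is a minimal sketch, assuming the sequences are plain lists of integer token ids:

```python
import torch
from torch.nn.utils.rnn import pad_sequence

a = (([0, 1, 2], [3, 4], [5, 6, 7, 8]), 1)
seqs, label = a

# Convert each sequence to a 1-D LongTensor of token ids
tensors = [torch.tensor(s, dtype=torch.long) for s in seqs]
# Keep the true lengths; pack_padded_sequence needs them (on the CPU)
lengths = torch.tensor([len(s) for s in seqs])
# Right-pad with 0 to shape (num_sequences, max_len)
padded = pad_sequence(tensors, batch_first=True, padding_value=0)
print(padded)
print(lengths)
```

    In this example token id 0 is also a real token, so padded positions are only distinguishable through the lengths. Reserving a separate id for padding (and passing it as padding_idx to nn.Embedding) is a common alternative.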
    

    Defining embedding and RNN layer

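    Next, suppose we have an embedding layer and an RNN encoder defined as follows. The encoder sorts the batch by length in decreasing order, packs it, runs the LSTM, unpacks the output, and then restores the original order.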

    import numpy as np
    import torch
    import torch.nn as nn
    
    class EmbeddingLayer(nn.Module):
    
        def __init__(self, input_size, emsize):
            super(EmbeddingLayer, self).__init__()
            self.embedding = nn.Embedding(input_size, emsize)
    
        def forward(self, input_variable):
            # (batch, seq_len) token ids -> (batch, seq_len, emsize)
            return self.embedding(input_variable)
    
    
    class Encoder(nn.Module):
    
        def __init__(self, input_size, hidden_size, bidirection):
            super(Encoder, self).__init__()
            self.input_size = input_size
            self.hidden_size = hidden_size
            self.bidirection = bidirection
            self.rnn = nn.LSTM(self.input_size, self.hidden_size, batch_first=True,
                               bidirectional=self.bidirection)
    
        def forward(self, sent_variable, sent_len):
            # Sort by length in decreasing order (keep idx to undo it later)
            idx_sort = np.argsort(-sent_len)
            idx_unsort = np.argsort(idx_sort)
            sent_len = sent_len[idx_sort]
    
            sent_variable = sent_variable.index_select(0, torch.from_numpy(idx_sort))
    
            # Handling padding in Recurrent Networks
            sent_packed = nn.utils.rnn.pack_padded_sequence(sent_variable, sent_len.tolist(),
                                                            batch_first=True)
            sent_output = self.rnn(sent_packed)[0]
            sent_output = nn.utils.rnn.pad_packed_sequence(sent_output, batch_first=True)[0]
    
            # Un-sort by length to restore the original batch order
            sent_output = sent_output.index_select(0, torch.from_numpy(idx_unsort))
    
            return sent_output
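
    Since PyTorch 1.1, pack_padded_sequence also accepts enforce_sorted=False, which does the sort/unsort bookkeeping internally: pad_packed_sequence and the LSTM's final hidden state come back in the original batch order. A minimal sketch of the same encoding step under that assumption; the helper name encode and the random input are illustrative only:

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

torch.manual_seed(0)

def encode(rnn, embedded, lengths):
    """Hypothetical helper: run an LSTM over an unsorted padded batch.

    embedded: (batch, max_len, input_size); lengths: 1-D CPU tensor.
    Returns the padded outputs and h_n, both in the original batch order.
    """
    packed = pack_padded_sequence(embedded, lengths, batch_first=True,
                                  enforce_sorted=False)
    packed_out, (h_n, c_n) = rnn(packed)
    # Unpacking undoes the internal sort and zero-fills padded steps
    output, _ = pad_packed_sequence(packed_out, batch_first=True)
    return output, h_n

rnn = nn.LSTM(input_size=50, hidden_size=100, batch_first=True)
embedded = torch.randn(3, 4, 50)    # stand-in for emb(variable_a)
lengths = torch.tensor([3, 2, 4])   # deliberately not sorted
output, h_n = encode(rnn, embedded, lengths)
print(output.shape)  # torch.Size([3, 4, 100])
print(h_n.shape)     # torch.Size([1, 3, 100])
```

    For a unidirectional LSTM, h_n[0, i] equals output[i, lengths[i] - 1], i.e. the state at each sequence's last real token rather than at the padding.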
    

    Embed and encode the processed input data

    We can embed and encode our input as follows.

    emb = EmbeddingLayer(vocab_size, 50)
    enc = Encoder(50, 100, False)
    
    emb_a = emb(variable_a)
    enc_a = enc(emb_a, len_a)
    

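    If you print the size of enc_a, you will get torch.Size([3, 4, 100]): 3 sequences, each padded to the longest length (4), with a 100-dimensional LSTM output at every time step. pad_packed_sequence fills the positions beyond each sequence's true length with zeros.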

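    Note that the code above runs on the CPU only: the index tensors are built with torch.from_numpy and therefore live on the CPU, so they would have to be moved to the input's device (e.g. with .to(sent_variable.device)) before index_select could run on a GPU.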
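
    For the setup described in the question (three separate LSTMs whose results are merged), one option is to batch the examples per input slot: all first inputs form one padded batch for the first LSTM, all second inputs another, and so on. Each slot is padded and packed independently, so there is no need to sort the examples jointly across slots; as long as each encoder returns rows in the original order (via the sort/unsort above, or enforce_sorted=False), row i of every slot still belongs to example i and its label. A minimal sketch; ThreeInputClassifier, collate, concatenating the final hidden states, and all sizes are hypothetical choices, not part of the original answer:

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence

class ThreeInputClassifier(nn.Module):
    # Hypothetical model: one embedding + LSTM per input slot
    def __init__(self, vocab_size, emb_size, hidden_size, num_classes, num_inputs=3):
        super().__init__()
        self.embeddings = nn.ModuleList(
            [nn.Embedding(vocab_size, emb_size) for _ in range(num_inputs)])
        self.rnns = nn.ModuleList(
            [nn.LSTM(emb_size, hidden_size, batch_first=True) for _ in range(num_inputs)])
        self.out = nn.Linear(num_inputs * hidden_size, num_classes)

    def forward(self, slots):
        # slots: list of (padded_ids, lengths) pairs, one per input slot
        finals = []
        for emb, rnn, (ids, lengths) in zip(self.embeddings, self.rnns, slots):
            packed = pack_padded_sequence(emb(ids), lengths, batch_first=True,
                                          enforce_sorted=False)
            _, (h_n, _) = rnn(packed)
            # h_n[-1]: final hidden state of each example, in original order
            finals.append(h_n[-1])
        # Merge the per-slot encodings and predict the shared label
        return self.out(torch.cat(finals, dim=1))

def collate(examples, num_inputs=3):
    # Hypothetical collate: group the i-th input of every example
    slots = []
    for i in range(num_inputs):
        seqs = [torch.tensor(x[0][i], dtype=torch.long) for x in examples]
        lengths = torch.tensor([len(s) for s in seqs])
        slots.append((pad_sequence(seqs, batch_first=True), lengths))
    labels = torch.tensor([x[1] for x in examples])
    return slots, labels

a = (([0, 1, 2], [3, 4], [5, 6, 7, 8]), 1)
b = (([0, 1], [2], [6, 7, 8, 9, 10]), 1)
slots, labels = collate([a, b])
# vocab_size=11 covers the largest token id (10) in a and b
model = ThreeInputClassifier(vocab_size=11, emb_size=50, hidden_size=100, num_classes=2)
logits = model(slots)
print(logits.shape)  # torch.Size([2, 2])
loss = nn.functional.cross_entropy(logits, labels)
```

    Concatenating the final hidden states is just one way to merge the three encodings; summing them or pooling over the padded outputs would work with the same per-slot batching.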