
Masking and computing loss for a padded batch sent through an RNN with a linear output layer in PyTorch


Although this is a typical use case, I can't find a simple, clear guide to the canonical way to compute the loss on a padded minibatch in PyTorch after sending it through an RNN.

I think a canonical pipeline could be:

  1. The PyTorch RNN (with the default batch_first=False) expects a padded batch tensor of shape (max_seq_len, batch_size, emb_size).

  2. So we feed an Embedding layer, for example, this tensor of indices with shape (max_seq_len, batch_size):

    tensor([[1, 1], [2, 2], [3, 9]])

9 is the padding index and the batch size is 2. The Embedding layer turns it into a tensor of shape (max_seq_len, batch_size, emb_size). The sequences in the batch are sorted by length in descending order, so we can pack it.

  3. We apply pack_padded_sequence, then the RNN, and finally pad_packed_sequence. At this point we have a tensor of shape (max_seq_len, batch_size, hidden_size).

  4. Now we apply the linear output layer to the result and, let's say, log_softmax. In the end we have a batch of scores of shape (max_seq_len, batch_size, linear_out_size).
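The steps above can be sketched roughly as follows. This is a minimal illustration, not a canonical recipe: the sizes (emb_size, hidden_size, linear_out_size), the vocabulary size of 10 and the use of a plain nn.RNN are hypothetical choices made only so the snippet runs. Lengths are derived from the padding index 9 used in the example tensor.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

torch.manual_seed(0)
PAD_IDX = 9  # padding index from the example tensor

# Step 2: padded index batch of shape (max_seq_len, batch_size) = (3, 2); columns are sequences
batch = torch.tensor([[1, 1], [2, 2], [3, 9]])
lengths = (batch != PAD_IDX).sum(dim=0)  # tensor([3, 2]), already in descending order

# Hypothetical sizes, chosen only for illustration
emb_size, hidden_size, linear_out_size = 4, 5, 10
embedding = nn.Embedding(10, emb_size, padding_idx=PAD_IDX)
rnn = nn.RNN(emb_size, hidden_size)  # batch_first=False: expects (seq, batch, feature)
linear = nn.Linear(hidden_size, linear_out_size)

embedded = embedding(batch)  # (3, 2, emb_size)

# Step 3: pack (enforce_sorted=True by default, so lengths must be descending), run, unpack
packed = pack_padded_sequence(embedded, lengths)
packed_out, _ = rnn(packed)
rnn_out, out_lengths = pad_packed_sequence(packed_out)  # (3, 2, hidden_size); padded steps are 0.0

# Step 4: linear layer + log_softmax over the class dimension.
# Padded positions still get scores (from the bias), which is why they must be masked in the loss.
log_probs = torch.log_softmax(linear(rnn_out), dim=-1)  # (3, 2, linear_out_size)
```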

How should I compute the loss from here, masking out the padded positions (whose targets can be arbitrary)?


Solution

  • I think the PyTorch Chatbot Tutorial might be instructive for you.

    Basically, you compute a mask of the valid output positions (padding positions are not valid) and use it to compute the loss over those positions only.
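    Applied to the question's setup, where log_softmax scores for the whole padded batch come out at once, the idea might look like the sketch below. The helper names masked_nll_loss and length_mask are hypothetical, not part of PyTorch or the tutorial. The one assumption on the "arbitrary" padded targets is that they are still valid class indices, because gather reads them before the mask is applied.

```python
import torch

def masked_nll_loss(log_probs, target, mask):
    """Mean negative log-likelihood over the non-padded positions only.

    log_probs: (max_seq_len, batch_size, num_classes), output of log_softmax
    target:    (max_seq_len, batch_size) class indices; padded entries may hold
               any index in [0, num_classes), they are ignored
    mask:      (max_seq_len, batch_size) bool, True at real (non-padded) positions
    """
    # Log-probability of the target class at every position: (max_seq_len, batch_size)
    picked = log_probs.gather(2, target.unsqueeze(2)).squeeze(2)
    # Keep only the real positions, then average over them
    return -picked[mask].mean()

def length_mask(lengths, max_seq_len):
    """Bool mask (max_seq_len, batch_size): True where t < length of that sequence."""
    steps = torch.arange(max_seq_len).unsqueeze(1)  # (max_seq_len, 1)
    return steps < lengths.unsqueeze(0)             # broadcasts to (max_seq_len, batch_size)

# Usage with random scores and the lengths [3, 2] from the question's example
torch.manual_seed(0)
log_probs = torch.log_softmax(torch.randn(3, 2, 10), dim=-1)
target = torch.tensor([[4, 7], [1, 2], [0, 5]])  # target[2, 1] is padding: arbitrary value
mask = length_mask(torch.tensor([3, 2]), max_seq_len=3)
loss = masked_nll_loss(log_probs, target, mask)
```

    Selecting with the mask (rather than multiplying the per-position loss by it) keeps any inf or NaN at padded positions out of the result.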

    See the outputVar and maskNLLLoss functions on the tutorial page. For convenience, I copied the code here, but you really need to read it in the context of the rest of the tutorial.

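    import torch

    # Global used by the tutorial; defined here so the snippet is complete
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Returns padded target sequence tensor, padding mask, and max target length.
    # indexesFromSentence, zeroPadding and binaryMatrix are helpers defined
    # elsewhere in the tutorial.
    def outputVar(l, voc):
        indexes_batch = [indexesFromSentence(voc, sentence) for sentence in l]
        max_target_len = max([len(indexes) for indexes in indexes_batch])
        padList = zeroPadding(indexes_batch)  # (max_target_len, batch_size), padded
        mask = binaryMatrix(padList)          # 1 for real tokens, 0 for padding
        mask = torch.BoolTensor(mask)
        padVar = torch.LongTensor(padList)
        return padVar, mask, max_target_len

    # Called once per decoder time step t, with:
    #   inp:    (batch_size, vocab_size) softmax probabilities (not log-probabilities)
    #   target: (batch_size,) target indices at step t
    #   mask:   (batch_size,) bool mask at step t
    # With log_softmax outputs, drop the torch.log: -torch.gather(...).squeeze(1)
    def maskNLLLoss(inp, target, mask):
        nTotal = mask.sum()
        crossEntropy = -torch.log(torch.gather(inp, 1, target.view(-1, 1)).squeeze(1))
        loss = crossEntropy.masked_select(mask).mean()
        loss = loss.to(device)
        return loss, nTotal.item()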
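    A built-in alternative, not used by the tutorial, is the ignore_index argument of torch.nn.functional.nll_loss. It assumes the padded target positions hold one dedicated pad value (here 9, reused from the question as a hypothetical target pad index), whereas the explicit mask tolerates arbitrary padded targets. The sketch below checks that both give the same number on random scores.

```python
import torch
import torch.nn.functional as F

PAD_IDX = 9  # assumption: padded target positions are filled with this index
torch.manual_seed(0)
max_seq_len, batch_size, num_classes = 3, 2, 10
log_probs = torch.log_softmax(torch.randn(max_seq_len, batch_size, num_classes), dim=-1)
target = torch.tensor([[4, 7], [1, 2], [0, PAD_IDX]])  # (max_seq_len, batch_size)

# Built-in: flatten time and batch, and let ignore_index skip the padded targets.
# With reduction="mean" (the default) the average is over non-ignored targets only.
loss_builtin = F.nll_loss(
    log_probs.reshape(-1, num_classes), target.reshape(-1), ignore_index=PAD_IDX
)

# Same quantity with an explicit mask, in the spirit of maskNLLLoss but on
# log-probabilities (so no torch.log) and over all time steps at once
mask = target != PAD_IDX
picked = log_probs.gather(2, target.unsqueeze(2)).squeeze(2)
loss_masked = -picked[mask].mean()
```

    If the model returns raw logits instead of log_softmax scores, F.cross_entropy accepts the same ignore_index argument.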