python, machine-learning, pytorch, recurrent-neural-network, language-model

In PyTorch, what's the difference between training an RNN to predict the last word given a sequence, vs predicting the entire sequence shifted?


Let's say I'm trying to train an RNN language model in PyTorch. Suppose I iterate over batches of word sequences, and that each training batch tensor has the following shape:

data.shape = [batch_size, sequence_length, vocab_dim]

My question is, what's the difference between using only the last word in each sequence as the target label:

X = data[:,:-1]
y = data[:,-1]

and training to minimize the loss on a softmax prediction of that last word,

vs setting the target to be the entire sequence shifted right:

X = data[:,:-1]
y = data[:,1:]

and training to minimize the sum of losses of each predicted word in the shifted sequence?

What's the correct approach here? I feel like I've seen both examples online. Does this also have to do with loop unrolling vs BPTT?


Solution

  • Consider the sequence prediction problem a b c d where you want to train an RNN via teacher forcing.

    If you only use the last word in the sentence, you are doing the following classification problem (on the left is the input; on the right is the output you're supposed to predict):

    a b c -> d
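    In PyTorch terms, this corresponds to something like the following minimal sketch. It is purely illustrative: the embedding/GRU/linear modules, the dimensions, and the use of integer word indices instead of one-hot vectors are all assumptions on my part, not taken from your code.

        import torch
        import torch.nn as nn

        # hypothetical model and sizes, just to make the shapes concrete
        vocab_size, embed_dim, hidden_dim = 10_000, 128, 256
        embed = nn.Embedding(vocab_size, embed_dim)
        rnn = nn.GRU(embed_dim, hidden_dim, batch_first=True)
        head = nn.Linear(hidden_dim, vocab_size)
        loss_fn = nn.CrossEntropyLoss()

        data = torch.randint(0, vocab_size, (32, 20))   # [batch_size, seq_len] word indices
        X, y = data[:, :-1], data[:, -1]                # inputs vs. a single target word

        out, _ = rnn(embed(X))                          # out: [batch, seq_len-1, hidden]
        logits = head(out[:, -1])                       # keep only the final time step
        loss = loss_fn(logits, y)                       # one prediction per sequence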

    For your second approach, where y is set to be the entire sequence shifted right, you are doing three classification problems:

    a -> b
    a b -> c
    a b c -> d
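    A corresponding sketch for this second setup, reusing the hypothetical embed, rnn, head, loss_fn and data from the sketch above, applies the classification head at every time step rather than only the last one:

        X, y = data[:, :-1], data[:, 1:]                # targets are the inputs shifted by one

        out, _ = rnn(embed(X))                          # [batch, seq_len-1, hidden]
        logits = head(out)                              # [batch, seq_len-1, vocab]
        loss = loss_fn(logits.reshape(-1, vocab_size),  # CrossEntropyLoss wants [N, C] logits
                       y.reshape(-1))                   # and [N] targets; it averages by default,
                                                        # pass reduction='sum' to sum the per-word losses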
    

    The task of predicting the intermediate words in a sequence is crucial for training a useful RNN (otherwise, you would know how to predict d after seeing a b c, but you wouldn't know how to continue after just a or a b).

    An equivalent thing to do would be to define your training data as both the complete sequence a b c d and all incomplete sequences (a b, a b c). Then if you were to do just the "last word" prediction as mentioned previously, you would end up with the same supervision as the formulation where y is the entire sequence shifted right. But this is computationally wasteful - you don't want to rerun the RNN on both a b and a b c (the state you get from a b can be reused to obtain the state after consuming a b c).

    In other words, the point of doing the "shift y right" is to split a single sequence (a b c d) of length N into N - 1 independent classification problems of the form: "given words up to time t, predict word t + 1", while needing just one RNN forward pass.
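    To make the "just one RNN forward pass" point concrete, here is a small sanity check using the same hypothetical modules as above: the output the RNN produces at step t of a single pass over the full input matches what you would get by rerunning it from scratch on the length-t prefix, which is exactly why all N - 1 predictions can come out of one pass.

        with torch.no_grad():
            full_out, _ = rnn(embed(X))                 # one pass over the whole input
            for t in range(1, X.size(1) + 1):
                prefix_out, _ = rnn(embed(X[:, :t]))    # rerun from scratch on each prefix
                # the state after reading t words is the same either way
                assert torch.allclose(full_out[:, t - 1], prefix_out[:, -1], atol=1e-6)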