I'm trying to understand the code of this Transformer implementation (https://github.com/SamLynnEvans/Transformer).
In the train_model function of the train script, I wonder why trg_input is given a different sequence length from trg:
```python
trg_input = trg[:, :-1]
```
Here trg_input has length seq_len(trg) - 1. If trg looks like:
<sos> tok1 tok2 ... tokn <eos>
then trg_input looks like:
<sos> tok1 tok2 ... tokn (no <eos> token)
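The slicing can be illustrated on a single tokenized sentence in plain Python. This is a toy sketch with made-up tokens, not code from the repository: trg_input drops the last token, the labels (ys in the training loop) drop the first, and pairing them shows which token each decoder position is trained to predict.

```python
# A toy target sequence (hypothetical tokens, not from the repository).
trg = ["<sos>", "tok1", "tok2", "tokn", "<eos>"]

trg_input = trg[:-1]  # same idea as trg[:, :-1] for one sequence: drop the last token
ys = trg[1:]          # same idea as trg[:, 1:]: drop the first token

# Pair each decoder position with the token it should predict next.
for position, (seen, target) in enumerate(zip(trg_input, ys)):
    print(f"position {position}: last token seen {seen!r} -> predict {target!r}")
```

Both slices have the same length, so every decoder position gets exactly one label, and the final token of trg only ever appears as a label.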
Please let me know the reason.
Thank you.
The related code is below:
```python
import torch.nn.functional as F

for i, batch in enumerate(opt.train):
    src = batch.src.transpose(0, 1).to('cuda')
    trg = batch.trg.transpose(0, 1).to('cuda')
    trg_input = trg[:, :-1]  # decoder input: every target token except the last
    src_mask, trg_mask = create_masks(src, trg_input, opt)
    preds = model(src, trg_input, src_mask, trg_mask)
    ys = trg[:, 1:].contiguous().view(-1)  # labels: every target token except the first
    loss = F.cross_entropy(preds.view(-1, preds.size(-1)), ys, ignore_index=opt.trg_pad)
```
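To see what flattening trg[:, 1:] and ignore_index=opt.trg_pad do for a padded batch, here is a plain-Python sketch of the same bookkeeping. The pad id 1 and the token ids are assumptions chosen for illustration; only label positions that are not padding contribute to the cross-entropy.

```python
PAD = 1  # assumed pad id for this illustration (plays the role of opt.trg_pad)

# Two target sequences of different lengths, padded to the same length.
# Assumed ids: 2 = <sos>, 3 = <eos>, other values are ordinary tokens.
trg = [
    [2, 10, 11, 12, 3],
    [2, 20, 3, PAD, PAD],
]

trg_input = [row[:-1] for row in trg]          # like trg[:, :-1]
ys = [tok for row in trg for tok in row[1:]]   # like trg[:, 1:].contiguous().view(-1)

# Labels equal to PAD are skipped by ignore_index, so they add nothing to the loss.
counted = [i for i, label in enumerate(ys) if label != PAD]
print(ys)       # flattened labels, one per decoder position in the batch
print(counted)  # indices of labels that contribute to the loss
```

The number of flattened labels equals the number of decoder positions (batch size times seq_len - 1), which is what makes preds.view(-1, preds.size(-1)) line up with ys. For the shorter sequence, <eos> does end up inside trg_input, but the label at that position is padding, so its prediction is ignored.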
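```python
def create_masks(src, trg, opt):
    # Source padding mask: True where src holds a real token.
    src_mask = (src != opt.src_pad).unsqueeze(-2)
    if trg is not None:
        # Target padding mask, combined below with the "no peek" (causal) mask.
        trg_mask = (trg != opt.trg_pad).unsqueeze(-2)
        size = trg.size(1)  # get seq_len for matrix
        np_mask = nopeak_mask(size, opt)  # hides future positions; defined elsewhere in the repo
        if trg.is_cuda:
            np_mask = np_mask.cuda()  # keep the mask on the same device as trg
        trg_mask = trg_mask & np_mask
    else:
        trg_mask = None
    return src_mask, trg_mask
```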
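The effect of combining the padding mask with nopeak_mask can be sketched in plain Python for a single sequence. This illustrates the idea and is not the repository's nopeak_mask: entry [i][j] is True when decoder position i may attend to position j, which requires j <= i and token j not being padding. The pad id and token ids are assumed.

```python
def combined_mask(tokens, pad):
    """Return a size x size boolean mask for one target sequence.

    mask[i][j] is True when position i may attend to position j:
    j must not lie in the future (j <= i) and token j must not be padding.
    """
    size = len(tokens)
    return [
        [j <= i and tokens[j] != pad for j in range(size)]
        for i in range(size)
    ]


PAD = 1  # assumed pad id, standing in for opt.trg_pad
trg_input = [2, 20, 3, PAD]  # hypothetical ids: <sos>, a token, <eos>, padding

for row in combined_mask(trg_input, PAD):
    print(" ".join("x" if allowed else "." for allowed in row))
```

Row i is what decoder position i can see, so position i predicts the next token from trg_input[:i+1] only, even though the whole shifted sequence is fed in at once.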
That's because the whole aim is to generate the next token from the tokens seen so far. Look at what goes into the model when the predictions are computed: not only the source sequence, but also the target sequence up to the current step. The model in Models.py looks like this:
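```python
import torch.nn as nn

# Encoder and Decoder are defined elsewhere in Models.py
class Transformer(nn.Module):
    def __init__(self, src_vocab, trg_vocab, d_model, N, heads, dropout):
        super().__init__()  # required before assigning submodules to an nn.Module
        self.encoder = Encoder(src_vocab, d_model, N, heads, dropout)
        self.decoder = Decoder(trg_vocab, d_model, N, heads, dropout)
        self.out = nn.Linear(d_model, trg_vocab)  # decoder states -> vocabulary logits

    def forward(self, src, trg, src_mask, trg_mask):
        e_outputs = self.encoder(src, src_mask)  # encode the full source sentence
        d_output = self.decoder(trg, e_outputs, src_mask, trg_mask)  # decode the shifted target
        output = self.out(d_output)
        return output
```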
So the forward method receives src and trg, which are fed into the encoder and the decoder respectively. This is easier to grasp from the model architecture figure in the original paper ("Attention Is All You Need", Figure 1).
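In that figure, the decoder input labeled "Outputs (shifted right)" corresponds to trg[:, :-1] in the code, while the decoder's predictions are scored against trg[:, 1:] (ys). Each position is therefore trained to predict the token that follows it, and the last token of trg has no successor to predict, which is why it is dropped from the decoder input.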
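At inference time the same idea is applied one step at a time. The sketch below uses a hypothetical next_token function in place of the trained model (it simply replays a fixed sequence), so it only shows the control flow of greedy decoding under that assumption: the decoder input starts as [<sos>], each predicted token is appended, and decoding stops once <eos> is produced, so <eos> never has to be fed back in.

```python
from typing import Callable, List


def greedy_decode(next_token: Callable[[List[str]], str], max_len: int = 10) -> List[str]:
    """Grow the decoder input one token at a time, starting from <sos>.

    next_token is a hypothetical stand-in for a model call that returns the
    most likely token given the target tokens generated so far (the source
    sentence is assumed fixed and already encoded).
    """
    decoded = ["<sos>"]
    for _ in range(max_len):
        token = next_token(decoded)  # predict from everything seen so far
        decoded.append(token)
        if token == "<eos>":  # nothing is ever predicted after <eos>
            break
    return decoded


# Hypothetical stand-in for a trained model: replays a fixed translation.
REFERENCE = ["<sos>", "tok1", "tok2", "tokn", "<eos>"]


def toy_next_token(prefix: List[str]) -> str:
    return REFERENCE[len(prefix)]


print(greedy_decode(toy_next_token))
```

The prefixes passed to next_token here ([<sos>], [<sos>, tok1], and so on up to [<sos>, tok1, tok2, tokn]) are exactly what decoder positions 0 to 3 can see in trg_input under the no-peek mask during training.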