Tags: pytorch, recurrent-neural-network, encoder-decoder

Decoder always predicts the same token


I have the following decoder for machine translation that, after a few steps, only predicts the EOS token. Because of this, it cannot even overfit a tiny dummy dataset, so there seems to be a serious bug in the code.

Decoder(
  (embedding): Embeddings(
    (word_embeddings): Embedding(30002, 768, padding_idx=3)
    (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (dropout): Dropout(p=0.5, inplace=False)
  )
  (ffn1): FFN(
    (dense): Linear(in_features=768, out_features=512, bias=False)
    (layernorm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    (dropout): Dropout(p=0.5, inplace=False)
    (activation): GELU()
  )
  (rnn): GRU(512, 512, batch_first=True, bidirectional=True)
  (ffn2): FFN(
    (dense): Linear(in_features=1024, out_features=512, bias=False)
    (layernorm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    (dropout): Dropout(p=0.5, inplace=False)
    (activation): GELU()
  )
  (selector): Sequential(
    (0): Linear(in_features=512, out_features=30002, bias=True)
    (1): LogSoftmax(dim=-1)
  )
)

The forward pass is relatively straightforward (see what I did there?): pass the input_ids through the embedding layer and an FFN, then feed that representation to the RNN with the given sembedding as its initial hidden state. Pass the RNN output through another FFN and apply a log-softmax. Return the logits (strictly speaking, log-probabilities) and the last hidden states of the RNN. In the next step, those hidden states become the new hidden states, and the highest-scoring predicted token becomes the new input.

def forward(self, input_ids, sembedding):
    embedded = self.embedding(input_ids)
    output = self.ffn1(embedded)
    output, hidden = self.rnn(output, sembedding)
    output = self.ffn2(output)
    logits = self.selector(output)

    return logits, hidden
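For readers who want to reproduce the setup, here is a minimal, self-contained sketch of a decoder with the same layer layout as the printout above, shrunk to toy sizes. The plain nn.Embedding standing in for the custom Embeddings wrapper, the Linear, LayerNorm, GELU, Dropout order inside FFN, and the constructor arguments are assumptions, not the original code. Since the selector ends in LogSoftmax, the returned scores are log-probabilities.

```python
import torch
from torch import nn


class FFN(nn.Module):
    """Linear -> LayerNorm -> GELU -> Dropout (order assumed from the printed module list)."""

    def __init__(self, in_dim, out_dim, dropout=0.5):
        super().__init__()
        self.dense = nn.Linear(in_dim, out_dim, bias=False)
        self.layernorm = nn.LayerNorm(out_dim)
        self.dropout = nn.Dropout(dropout)
        self.activation = nn.GELU()

    def forward(self, x):
        return self.dropout(self.activation(self.layernorm(self.dense(x))))


class Decoder(nn.Module):
    def __init__(self, vocab_size, emb_dim, hidden_dim, pad_idx, bidirectional=True):
        super().__init__()
        self.num_directions = 2 if bidirectional else 1
        # Hypothetical simplification: plain nn.Embedding instead of the post's Embeddings wrapper
        self.embedding = nn.Embedding(vocab_size, emb_dim, padding_idx=pad_idx)
        self.ffn1 = FFN(emb_dim, hidden_dim)
        self.rnn = nn.GRU(hidden_dim, hidden_dim, batch_first=True, bidirectional=bidirectional)
        self.ffn2 = FFN(hidden_dim * self.num_directions, hidden_dim)
        self.selector = nn.Sequential(nn.Linear(hidden_dim, vocab_size), nn.LogSoftmax(dim=-1))

    def forward(self, input_ids, sembedding):
        embedded = self.embedding(input_ids)           # (batch, seq, emb_dim)
        output = self.ffn1(embedded)                   # (batch, seq, hidden_dim)
        output, hidden = self.rnn(output, sembedding)  # hidden: (num_directions, batch, hidden_dim)
        output = self.ffn2(output)                     # (batch, seq, hidden_dim)
        log_probs = self.selector(output)              # (batch, seq, vocab_size)
        return log_probs, hidden


# Toy usage: one decoding step for a batch of 4
torch.manual_seed(0)
dec = Decoder(vocab_size=50, emb_dim=16, hidden_dim=8, pad_idx=3)
ids = torch.randint(0, 50, (4, 1))
h0 = torch.zeros(2, 4, 8)
log_probs, hidden = dec(ids, h0)
```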

sembedding is the initial hidden state for the RNN. This is similar to an encoder-decoder architecture, except that here we do not train the encoder; instead, we have access to pretrained encoder representations.
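One detail worth keeping in mind when using sembedding as the initial state: batch_first=True only changes the layout of the input and output tensors; nn.GRU still expects h_0 as (num_layers * num_directions, batch, hidden_size). A small sketch, assuming a single-layer bidirectional GRU and a random stand-in for the pretrained sentence embeddings:

```python
import torch
from torch import nn

batch_size, hidden_dim = 4, 8
# Hypothetical pretrained sentence embeddings: one vector per example, shape (batch, hidden)
sembedding = torch.randn(batch_size, hidden_dim)

rnn = nn.GRU(hidden_dim, hidden_dim, batch_first=True, bidirectional=True)

# batch_first does not apply to h_0: it stays (num_layers * num_directions, batch, hidden)
# Forward direction starts from sembedding, backward direction from zeros
h0 = torch.stack((sembedding, torch.zeros_like(sembedding)))
x = torch.randn(batch_size, 1, hidden_dim)  # (batch, seq_len=1, input_size)
output, h_n = rnn(x, h0)
```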

In my training loop, I start each batch with an SOS token and feed the top predicted token into the next step until target_len is reached. I also randomly decide, per batch, whether to use teacher forcing.

def step(self, batch, teacher_forcing_ratio=0.5):
    batch_size, target_len = batch["input_ids"].size()[:2]
    # Initialise the first decoder input with the SOS (BOS) token
    decoder_input = torch.tensor([[self.tokenizer.bos_token_id]] * batch_size).to(self.device)
    batch["input_ids"] = batch["input_ids"].to(self.device)

    # Initialise the first decoder hidden state; if the RNN is bidirectional, add a zeroed second state
    decoder_hidden = torch.stack((batch["sembedding"],
                                  torch.zeros(*batch["sembedding"].size()))
                                 ).to(self.device) if self.model.num_directions == 2 \
        else batch["sembedding"].unsqueeze(0).to(self.device)

    loss = torch.tensor([0.]).to(self.device)

    use_teacher_forcing = random.random() < teacher_forcing_ratio
    # contains tuples of predicted and correct words
    tokens = []
    for i in range(target_len):
        # overwrite previous decoder_hidden
        output, decoder_hidden = self.model(decoder_input, decoder_hidden)
        batch_correct_ids = batch["input_ids"][:, i]

        # NLLLoss between the predicted log-probabilities (bs x classes) and the correct classes for _this word_
        # (the criterion is set to ignore the padding index)
        loss += self.criterion(output[:, 0, :], batch_correct_ids)

        batch_predicted_ids = output.topk(1).indices.squeeze(1).detach()

        # if teacher forcing: feed the current correct word as the next input
        # otherwise: feed the current prediction as the next input
        decoder_input = batch_correct_ids.unsqueeze(1) if use_teacher_forcing else batch_predicted_ids

    return loss, loss.item() / target_len
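The greedy selection in step() relies on tensor shapes lining up. The sketch below, with toy sizes and random scores in place of real decoder output, checks that topk(1).indices.squeeze(1) and argmax(dim=-1) both yield a (batch, 1) tensor of token ids, the same shape as the teacher-forced input:

```python
import torch

torch.manual_seed(0)
batch_size, vocab_size, target_len = 4, 10, 5

# Stand-in for one decoder step: (batch, seq_len=1, vocab) log-probabilities
output = torch.randn(batch_size, 1, vocab_size).log_softmax(dim=-1)

# Route used in step(): topk(1) -> (batch, 1, 1), squeeze(1) -> (batch, 1)
via_topk = output.topk(1).indices.squeeze(1)
# Equivalent alternative: argmax over the vocab dimension -> (batch, 1)
via_argmax = output.argmax(dim=-1)

# Teacher-forced branch: the correct ids at position 0, unsqueezed to (batch, 1)
targets = torch.randint(0, vocab_size, (batch_size, target_len))
teacher_input = targets[:, 0].unsqueeze(1)
```

Because both branches produce a (batch, 1) tensor, the decoder sees the same input shape regardless of whether teacher forcing is on.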

I also clip the gradients after each step:

clip_grad_norm_(self.model.parameters(), 1.0)
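For clip_grad_norm_ to have any effect, it has to run after loss.backward() and before optimizer.step(). A minimal sketch of one training iteration; the linear model, Adam optimizer and random data are hypothetical stand-ins for the real decoder and batches:

```python
import torch
from torch import nn
from torch.nn.utils import clip_grad_norm_

torch.manual_seed(0)
model = nn.Linear(8, 3)  # hypothetical stand-in for the decoder
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

x = torch.randn(16, 8)
y = torch.randint(0, 3, (16,))

optimizer.zero_grad()
loss = criterion(model(x), y)
loss.backward()                                        # gradients exist only after backward()
total_norm = clip_grad_norm_(model.parameters(), 1.0)  # rescales grads in place, returns the pre-clipping norm
optimizer.step()                                       # the update uses the clipped gradients

# Norm of the gradients that were actually applied
clipped_norm = torch.norm(torch.stack([p.grad.norm() for p in model.parameters()]))
```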

At first, consecutive predictions are already fairly similar, but after a few iterations there is somewhat more variation. Quite quickly, however, ALL predictions turn into other words (always the same ones) and eventually into EOS tokens. (Edit: after changing the activation to ReLU, a different token is always predicted; it seems to be a random token that gets repeated.) Note that this already happens after 80 steps with a batch size of 128.

I found that the hidden state returned by the RNN contains a lot of zeros. I am not sure whether that is the problem, but it seems like it could be related.

tensor([[[  3.9874e-02,  -6.7757e-06,   2.6094e-04,  ...,  -1.2708e-17,
            4.1839e-02,   7.8125e-03],
         [ -7.8125e-03,  -2.5341e-02,   7.8125e-03,  ...,  -7.8125e-03,
           -7.8125e-03,  -7.8125e-03],
         [ -0.0000e+00, -1.0610e-314,   0.0000e+00,  ...,   0.0000e+00,
            0.0000e+00,   0.0000e+00],
         [  0.0000e+00,   0.0000e+00,   0.0000e+00,  ...,   0.0000e+00,
           -0.0000e+00,  1.0610e-314]]], device='cuda:0', dtype=torch.float64,
       grad_fn=<CudnnRnnBackward>)
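Two things stand out in that printout: dtype=torch.float64, and values such as 1.0610e-314, which are subnormal doubles far below the float32 range. The hypothetical helper check_rnn_dtypes below is one quick way to compare the dtypes of the input, the hidden state and the GRU weights before the call; the toy sizes are assumptions.

```python
import torch
from torch import nn


def check_rnn_dtypes(rnn, inputs, hidden):
    """Hypothetical helper: report dtypes so a mismatch between input, hidden state and weights stands out."""
    weight_dtype = next(rnn.parameters()).dtype
    return {
        "weights": weight_dtype,
        "input": inputs.dtype,
        "hidden": hidden.dtype,
        "all_match": inputs.dtype == hidden.dtype == weight_dtype,
    }


rnn = nn.GRU(8, 8, batch_first=True)
x = torch.randn(2, 1, 8)                       # float32, like the embedded input
h = torch.zeros(1, 2, 8, dtype=torch.float64)  # float64, like the printed hidden state
report = check_rnn_dtypes(rnn, x, h)

# 1.0610e-314 is below the smallest normal float64 (about 2.2e-308), i.e. subnormal,
# and far below anything float32 can represent, so casting it to float32 gives 0.0
tiny = torch.tensor(1.0610e-314, dtype=torch.float64)
is_subnormal = tiny.item() < torch.finfo(torch.float64).tiny
as_float32 = tiny.float().item()
```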

I have no idea what is going wrong, although I suspect that the issue lies in my step function rather than in the model. I have already tried tuning the learning rate, disabling some layers (LayerNorm, dropout, ffn2), using pretrained embeddings (both frozen and unfrozen), disabling teacher forcing, and using a bidirectional vs. a unidirectional GRU. The end result is always the same.

If you have any pointers, that would be very helpful. I have searched extensively for neural networks that always predict the same item and have tried every suggestion I could find. Any new ones, no matter how crazy, are welcome!


Solution

  • In my case, the issue turned out to be that the initial hidden state had dtype double (float64) while the input was float (float32). I don't quite understand why that is an issue, but casting the hidden state to float solved it. If you have any intuition about why this might be a problem for PyTorch, do let me know in the comments or, better yet, on the official PyTorch forums.

    EDIT: as that topic shows, this is a bug in PyTorch 1.6 that is fixed in 1.7. In 1.7 you get an error message instead, which will hopefully save you the trouble of debugging all your code without finding the cause of the strange behaviour.
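A minimal sketch of the fix, assuming a single-layer GRU: build h_0 from the sentence embeddings and cast it to the dtype and device of the RNN's own parameters, so it cannot end up as float64 next to float32 weights. The helper name init_hidden and the float64 input are illustrative assumptions, not the original code.

```python
import torch
from torch import nn


def init_hidden(sembedding, rnn):
    """Hypothetical helper: build h_0 from (batch, hidden) embeddings, matching the RNN's dtype/device.

    Assumes num_layers == 1.
    """
    param = next(rnn.parameters())
    # The cast that resolved the issue: double -> the RNN's dtype (float32 here)
    h = sembedding.to(dtype=param.dtype, device=param.device)
    if rnn.bidirectional:
        # Forward direction starts from the embedding, backward direction from zeros
        return torch.stack((h, torch.zeros_like(h)))
    return h.unsqueeze(0)


rnn = nn.GRU(8, 8, batch_first=True, bidirectional=True)
sembedding = torch.randn(4, 8, dtype=torch.float64)  # e.g. embeddings loaded as float64 (assumed)
h0 = init_hidden(sembedding, rnn)
output, h_n = rnn(torch.randn(4, 1, 8), h0)
```

Using torch.zeros_like for the backward direction keeps the zero state on the same dtype and device as the already-cast forward state.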