Tags: python, nlp, pytorch, huggingface-transformers, pytorch-lightning

How to get generated tokens in T5 training_step for using user-defined metrics?


I am fine-tuning T5 for generative question answering and want to track additional metrics (e.g., BLEU, ROUGE) on the generated answers, alongside the loss.

For that, I believe I need to obtain the generated tokens (answers) at each training_step. However, even after reading the source code, I still don't see how to do that.

Below is an excerpt of my code. I can extract output.loss and output.logits, but I couldn't find a way to get the generated tokens for computing additional evaluation metrics.

Thanks in advance.

import pytorch_lightning as pl
from transformers import T5ForConditionalGeneration

class MyQAModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME, return_dict=True)

    def forward(self, input_ids, attention_mask, labels=None):
        output = self.model(
            input_ids,
            attention_mask=attention_mask,
            labels=labels)

        return output.loss, output.logits

    def training_step(self, batch, batch_idx):
        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask']
        labels = batch['labels']
        loss, outputs = self(input_ids, attention_mask, labels)
        self.log("train_loss", loss, prog_bar=True, logger=True)
        return {"loss": loss, "predictions": outputs, "labels": labels}
    
    ...
    (code continues...)
    ....

Solution

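  • You can obtain the predicted tokens from output.logits, which has shape [batch, seq_len, vocab_size], with torch.argmax(output.logits, dim=-1), which has shape [batch, seq_len]. Keep in mind that when labels are passed, the decoder is teacher-forced (each position is conditioned on the ground-truth prefix), so these are per-position greedy predictions rather than what model.generate() would produce; for free-running answers, call self.model.generate(input_ids=input_ids, attention_mask=attention_mask), typically in validation_step. Then, to decode the predicted sentences from a batch of token ids, run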

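    import torch

    # Assumes `tokenizer` is the T5Tokenizer that matches MODEL_NAME
    predicted_ids = torch.argmax(output.logits, dim=-1)  # [batch, seq_len]
    generated_sentences = []
    for predicted_token_ids in predicted_ids:
        # skip_special_tokens drops <pad> and </s> from the decoded text
        generated_sentences.append(
            tokenizer.decode(predicted_token_ids, skip_special_tokens=True))

    # For getting the original input sentences
    original_sentences = []
    for sent_ids in input_ids:
        original_sentences.append(tokenizer.decode(sent_ids, skip_special_tokens=True))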
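For metrics you want to log every training_step, you often don't need to decode at all. T5 builds its decoder inputs by shifting the labels right internally, so position i of the logits already lines up with labels[i]. The following is a minimal sketch, assuming your labels mark padded positions with -100 (the index Hugging Face's seq2seq loss ignores); if your labels are padded with the pad id instead, change the mask accordingly. The helper names teacher_forced_ids, labels_for_decoding and token_accuracy are hypothetical, and the toy tensors only stand in for a real batch.

```python
import torch

IGNORE_INDEX = -100  # assumed label value for positions the loss should skip


def teacher_forced_ids(logits: torch.Tensor) -> torch.Tensor:
    """Map [batch, seq_len, vocab_size] logits to [batch, seq_len] token ids."""
    return torch.argmax(logits, dim=-1)


def labels_for_decoding(labels: torch.Tensor, pad_token_id: int) -> torch.Tensor:
    """Return a copy with ignored positions replaced, so tokenizer.decode() gets valid ids."""
    return labels.masked_fill(labels == IGNORE_INDEX, pad_token_id)


def token_accuracy(pred_ids: torch.Tensor, labels: torch.Tensor) -> float:
    """Fraction of non-ignored label positions that the model predicted exactly."""
    mask = labels != IGNORE_INDEX
    correct = (pred_ids == labels) & mask
    return correct.sum().item() / max(mask.sum().item(), 1)


# Toy batch: 2 sequences, 3 positions, vocabulary of 4 tokens (pad id 0).
logits = torch.tensor([
    [[0.1, 2.0, 0.0, 0.0], [0.0, 0.0, 3.0, 0.0], [5.0, 0.0, 0.0, 0.0]],
    [[0.0, 0.0, 0.0, 1.0], [0.0, 4.0, 0.0, 0.0], [0.0, 0.0, 2.0, 0.0]],
])
labels = torch.tensor([[1, 2, -100], [3, 2, -100]])

pred_ids = teacher_forced_ids(logits)
print(pred_ids.tolist())                   # [[1, 2, 0], [3, 1, 2]]
print(token_accuracy(pred_ids, labels))    # 0.75
print(labels_for_decoding(labels, 0).tolist())  # [[1, 2, 0], [3, 2, 0]]
```

Inside training_step this could look like self.log("train_token_acc", token_accuracy(teacher_forced_ids(outputs), labels)). If you do want text (e.g., for ROUGE), tokenizer.batch_decode(labels_for_decoding(labels, tokenizer.pad_token_id), skip_special_tokens=True) gives the reference answers to compare against the decoded predictions.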
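BLEU and ROUGE are usually meant to score the answers the model actually produces at inference time, and those come from autoregressive decoding with generate(), not from the teacher-forced logits. The sketch that follows is one way to wrap that call; it builds a tiny randomly initialised T5 from T5Config purely so it runs without downloading weights (in your LightningModule you would pass self.model). The helper name generate_answers and all config sizes are hypothetical.

```python
import torch
from transformers import T5Config, T5ForConditionalGeneration


def generate_answers(model, input_ids, attention_mask, max_new_tokens=16):
    """Autoregressively generate answer token ids (no teacher forcing)."""
    was_training = model.training
    model.eval()  # disable dropout while generating
    with torch.no_grad():
        generated = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
        )
    model.train(was_training)  # restore the previous mode
    return generated  # [batch, <= 1 + max_new_tokens], starts with the decoder start token


# Tiny random model so the example runs offline (hypothetical sizes).
torch.manual_seed(0)
config = T5Config(
    vocab_size=32, d_model=16, d_kv=4, d_ff=32,
    num_layers=1, num_heads=2, decoder_start_token_id=0,
)
tiny_model = T5ForConditionalGeneration(config)

input_ids = torch.randint(2, 32, (3, 7))  # avoid pad (0) and eos (1) ids
attention_mask = torch.ones_like(input_ids)
generated = generate_answers(tiny_model, input_ids, attention_mask, max_new_tokens=5)
print(generated.shape)
```

Because generation is much slower than a single forward pass, it is usually done in validation_step rather than training_step; the resulting ids can then be turned into text with tokenizer.batch_decode(generated, skip_special_tokens=True) and passed to whatever BLEU/ROUGE implementation you use.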