
T5 model generates short output


I have fine-tuned the T5-base model (from Hugging Face) on a new task where each input and target is a sentence of 256 words. The loss converges to low values, but when I use the generate method the output is always too short. I tried passing minimal and maximal length values to the method, but that doesn't seem to be enough. I suspect the issue is related to the fact that the sentence length before tokenization is 256 words, while after tokenization it is not constant (padding is used during training to ensure all inputs are the same size). Here is my generate call:

import transformers
from transformers import T5Tokenizer

model = transformers.T5ForConditionalGeneration.from_pretrained('t5-base')
tokenizer = T5Tokenizer.from_pretrained('t5-base')

# ids and attn_mask are the tokenized input and attention mask (built below)
generated_ids = model.generate(
    input_ids=ids,
    attention_mask=attn_mask,
    max_length=1024,
    min_length=256,
    num_beams=2,
    early_stopping=False,
    repetition_penalty=10.0,
)

# skip_special_tokens=True already strips <pad> and </s>; the replace calls
# below are redundant but harmless cleanup.
preds = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True) for g in generated_ids][0]
preds = preds.replace("<pad>", "").replace("</s>", "").strip().replace("  ", " ")
target = [tokenizer.decode(t, skip_special_tokens=True, clean_up_tokenization_spaces=True) for t in reference][0]
target = target.replace("<pad>", "").replace("</s>", "").strip().replace("  ", " ")

The inputs are created using:

import torch

# Tokenize the source text, padding to a fixed length of 1024 tokens.
tokens = tokenizer([f"task: {text}"], return_tensors="pt", max_length=1024, padding='max_length', truncation=True)
inputs_ids = tokens.input_ids.squeeze().to(dtype=torch.long)
attention_mask = tokens.attention_mask.squeeze().to(dtype=torch.long)

# Tokenize the target text the same way to build the labels.
labels = tokenizer([target_text], return_tensors="pt", max_length=1024, padding='max_length', truncation=True)
label_ids = labels.input_ids.squeeze().to(dtype=torch.long)
label_attention = labels.attention_mask.squeeze().to(dtype=torch.long)
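
As an aside, a common pitfall with this setup (not part of the original question): when the labels are padded to a fixed length as above, the pad token ids in the labels are usually replaced with -100 so that the cross-entropy loss ignores them during fine-tuning; otherwise the model is also trained to produce padding, which can distort output lengths. A minimal sketch, reusing the label_ids tensor and tokenizer built above:

# -100 is the index that PyTorch's cross-entropy loss (and hence
# T5ForConditionalGeneration's built-in loss) ignores by default.
label_ids[label_ids == tokenizer.pad_token_id] = -100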

Solution

  • For anyone who runs into this: the issue was with the max_length argument of the generate method. It limits the maximum number of tokens including the input tokens, so in my case the fix was to set max_new_tokens=1024 instead of the max_length argument shown in the question.
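
    Concretely, the corrected call looks like the sketch below, keeping the same beam-search settings as in the question (ids and attn_mask are the tokenized input and attention mask from above):

    generated_ids = model.generate(
        input_ids=ids,
        attention_mask=attn_mask,
        max_new_tokens=1024,  # caps only the newly generated tokens
        min_length=256,
        num_beams=2,
        early_stopping=False,
        repetition_penalty=10.0,
    )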