
How to properly prompt the decoder of a Transformer model?


I am using Hugging Face Transformers. I have a pretrained encoder-decoder model (Pegasus) and want to fine-tune it as described in this article.

Specifically, they use the following process:

(Figure: Summary generation using entity prompts)

In other words, they prepend a manually constructed prompt (the entity chain) to the output the decoder generates.
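For a single training example, "prepending the prompt" amounts to building a target string like the one below. This is a minimal sketch that assumes the [ENTITYCHAIN] / [SUMMARY] markers and the " | " separator from the example in this question; the article may format them differently, and build_target is a hypothetical helper, not part of Transformers.

```python
# Hypothetical helper: marker strings and separator are assumed from the example in the question
def build_target(entities, summary):
    """Prepend an entity-chain prompt to the reference summary."""
    chain = " | ".join(entities)
    return f"[ENTITYCHAIN] {chain} [SUMMARY] {summary}"

print(build_target(["Frozen", "Disney"], "Frozen is an animated film by Disney."))
```

The tokenized form of this string is what the decoder sees during fine-tuning.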

My question is about the decoder input. Specifically, I want to fine-tune the model so that it takes the prompt (the entity chain) and generates the summary from that point onward.

For instance:

<s> [ENTITYCHAIN] Frozen | Disney [SUMMARY] $tok_1 $tok_2 $tok_3 ...
=========================================== ^^^^^^ ^^^^^^ ^^^^^^
This is not generated                       Generate from here
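The split in the diagram can be expressed as a per-position mask: True for prompt positions (up to and including [SUMMARY]), False for positions the model should generate. This sketch splits on whitespace for readability; with a real subword tokenizer you would have to locate the marker's token id(s) instead. prompt_mask is a hypothetical helper.

```python
def prompt_mask(tokens, marker="[SUMMARY]"):
    """Return True for prompt positions (through the marker), False for positions to generate."""
    cut = tokens.index(marker) + 1
    return [i < cut for i in range(len(tokens))]

# Whitespace split for illustration only; real tokenizers produce subword ids
tokens = "<s> [ENTITYCHAIN] Frozen | Disney [SUMMARY] $tok_1 $tok_2 $tok_3".split()
mask = prompt_mask(tokens)
for tok, is_prompt in zip(tokens, mask):
    print(f"{tok:15} {'prompt' if is_prompt else 'generate'}")
```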

As expected, however, the model produces a prediction for every token in the entity chain, which I do not need. More importantly, the loss also includes the predictions for the entity-chain tokens. This undermines the purpose of training: the model should learn to generate only the summary, not the entity chain, which is already given as a prompt.
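A toy illustration of the problem, using random logits rather than Pegasus: the default (mean) cross-entropy averages over every target position, so the prompt positions count as much as the summary positions. The 6-token prompt and 3-token summary lengths are arbitrary assumptions.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
vocab_size, prompt_len, summary_len = 10, 6, 3
seq_len = prompt_len + summary_len

# Random stand-ins for decoder logits and target ids (one sequence)
logits = torch.randn(seq_len, vocab_size)
targets = torch.randint(0, vocab_size, (seq_len,))

# Cross-entropy for each position separately
per_token = F.cross_entropy(logits, targets, reduction="none")

# Default behaviour: prompt positions are included in the average
loss_all = per_token.mean()

# Desired behaviour: average over the summary positions only
summary_mask = torch.arange(seq_len) >= prompt_len
loss_summary = per_token[summary_mask].mean()

print(f"all positions: {loss_all.item():.4f}  summary only: {loss_summary.item():.4f}")
```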

To restate: I want to give the decoder a prompt (the entity chain) and have it generate a summary while still attending to the information in the prompt. The loss should be computed only over the generated tokens, excluding the prompt tokens.

I couldn't find an option for this in the model documentation. Any ideas? :)


Solution

  • PyTorch's CrossEntropyLoss, which Hugging Face seq2seq models such as T5 and Pegasus use when you pass labels, has ignore_index=-100 by default: any target position whose label is -100 is skipped and excluded from the average. See the documentation for peace of mind.

    Here's a minimal code example:

    # Libraries
    from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

    # Get the tokenizer and the model
    checkpoint = 't5-small'
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)

    # Sample text
    inp = 'Here is my input'
    outp = 'Here is my output'

    # Get token IDs
    inp_ids = tokenizer(inp, return_tensors='pt').input_ids
    outp_ids = tokenizer(outp, return_tensors='pt').input_ids

    # Loss averaged over every target token
    loss = model(input_ids=inp_ids, labels=outp_ids).loss.item()
    print(loss)

    # Set the first target token to -100 and recalculate the loss;
    # that position is now excluded from the average
    modified_outp_ids = outp_ids.clone()
    modified_outp_ids[0, 0] = -100  # batch index 0: the batch holds a single sequence

    model_output = model(input_ids=inp_ids, labels=modified_outp_ids)
    print(model_output.loss.item())
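One caveat when applying this to the entity-chain setup: if you pass only labels, T5 and Pegasus build the decoder inputs by shifting labels one position to the right and replacing any -100 with the pad token. Masking the prompt in labels alone would therefore also hide the prompt from the decoder. A workaround is to build decoder_input_ids yourself from the unmasked prompt+summary ids and pass them together with the masked labels. The sketch below does this with plain tensors; shift_right and build_decoder_inputs are hypothetical helpers that mirror the library's right shift, and the token ids are made up (the trailing 1 stands in for the end-of-sequence token).

```python
import torch
import torch.nn.functional as F

def shift_right(ids, decoder_start_token_id):
    """Prepend the decoder start token and drop the last position (teacher forcing)."""
    start = torch.full((ids.size(0), 1), decoder_start_token_id, dtype=ids.dtype)
    return torch.cat([start, ids[:, :-1]], dim=1)

def build_decoder_inputs(target_ids, prompt_len, decoder_start_token_id):
    """Return (decoder_input_ids, labels); the first prompt_len positions are ignored by the loss."""
    decoder_input_ids = shift_right(target_ids, decoder_start_token_id)  # decoder still sees the prompt
    labels = target_ids.clone()
    labels[:, :prompt_len] = -100  # skipped by CrossEntropyLoss (default ignore_index)
    return decoder_input_ids, labels

# Made-up ids: 4 prompt (entity-chain) tokens followed by 3 summary tokens
target_ids = torch.tensor([[11, 12, 13, 14, 21, 22, 1]])
decoder_input_ids, labels = build_decoder_inputs(target_ids, prompt_len=4, decoder_start_token_id=0)
print(decoder_input_ids.tolist())  # [[0, 11, 12, 13, 14, 21, 22]]
print(labels.tolist())             # [[-100, -100, -100, -100, 21, 22, 1]]

# Sanity check with random logits of shape (batch, target length, vocab):
# the -100 positions drop out, so the loss equals the mean over the summary positions
torch.manual_seed(0)
vocab_size = 30
logits = torch.randn(1, 7, vocab_size)
loss = F.cross_entropy(logits.view(-1, vocab_size), labels.view(-1))
manual = F.cross_entropy(logits[0, 4:], labels[0, 4:])
print(loss.item(), manual.item())
```

With a real model you would then call model(input_ids=inp_ids, decoder_input_ids=decoder_input_ids, labels=labels); when decoder_input_ids is supplied, the model uses it instead of deriving it from labels. For T5 and Pegasus the decoder start token is the pad token (model.config.decoder_start_token_id).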