Tags: python, nlp, huggingface-transformers, generative-pretrained-transformer

How to prevent the transformer generate function from producing certain words?


I have the following code:

from transformers import T5Tokenizer, T5ForConditionalGeneration

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5ForConditionalGeneration.from_pretrained("t5-small")

input_ids = tokenizer("The <extra_id_0> walks in <extra_id_1> park", return_tensors="pt").input_ids

sequence_ids = model.generate(input_ids)
sequences = tokenizer.batch_decode(sequence_ids)
sequences

Currently it produces this:

['<pad><extra_id_0> park offers<extra_id_1> the<extra_id_2> park.</s>']

Is there a way to prevent the generator from producing certain words (e.g. stopwords = ["park", "offers"])?


Solution

  • After looking at the docs, I found there is a bad_words_ids parameter that you can pass to generate().

    Given a list of bad words, you can create the id list using:

    tokenizer(bad_words, add_special_tokens=False).input_ids
    
    input_ids = tokenizer("The <extra_id_0> walks in <extra_id_1> park", return_tensors="pt").input_ids
    bad_words = ["park", "offers"]
    bad_words_ids = tokenizer(bad_words, add_special_tokens=False).input_ids 
    #[[2447], [704]]
    
    sequence_ids = model.generate(input_ids, bad_words_ids=bad_words_ids)
    #tensor([[    0, 32099,  1061,    19,     3,     9,   710,  1482,   550,    45, 32098,     8, 32097,  1061,     5,     1]])
    
    sequences = tokenizer.batch_decode(sequence_ids)
    print(sequences) 
    #['<pad><extra_id_0> Park is a short walk away from<extra_id_1> the<extra_id_2> Park.</s>']
    
    

    Notice how the word "Park" appears now. This is because the tokenizer treats park (id 2447) and Park (id 1061) as two different tokens. Whether this happens depends on the tokenizer you use (some tokenizers are case-insensitive). If you don't want this to happen, add Park to the bad word list as well.
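    To cover all spellings with a case-sensitive tokenizer, you can expand the bad word list with capitalized and upper-case variants before tokenizing. A minimal sketch (the helper name expand_case_variants is hypothetical, not part of the transformers API):

    ```python
    def expand_case_variants(words):
        """Return each word plus its Capitalized and UPPER-CASE forms,
        so a case-sensitive tokenizer blocks every spelling."""
        variants = set()
        for w in words:
            variants.update({w, w.capitalize(), w.upper()})
        return sorted(variants)

    bad_words = expand_case_variants(["park", "offers"])
    # e.g. ['OFFERS', 'Offers', 'PARK', 'Park', 'offers', 'park']
    # then, as above:
    # bad_words_ids = tokenizer(bad_words, add_special_tokens=False).input_ids
    # sequence_ids = model.generate(input_ids, bad_words_ids=bad_words_ids)
    ```

    Note that each variant may tokenize to a different id (or to several ids for multi-token words), which bad_words_ids handles, since it accepts a list of id lists.
    
    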

    Colab demo