
Whitelist tokens for text generation (XLNet, GPT-2) in huggingface-transformers


In the documentation on text generation (https://huggingface.co/transformers/main_classes/model.html#generative-models) there is the option to pass

bad_words_ids (List[int], optional) – List of token ids that are not allowed to be generated. In order to get the tokens of the words that should not appear in the generated text, use tokenizer.encode(bad_word, add_prefix_space=True).

Is there also the option to put something along the lines of "allowed_words_ids"? The idea would be to restrict the language of the generated texts.


Solution

  • I'd also suggest doing what Sahar Mills said. You can do it in the following way.

    1. Get the whole vocab of the model you are using, e.g.
    from transformers import AutoTokenizer
    
    # Load tokenizer
    checkpoint = "CenIA/distillbert-base-spanish-uncased" #Example model
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    
    vocab = tokenizer.get_vocab() # dict mapping token -> token id
    print(list(vocab.keys())[:100]) # to see the first 100 tokens
    
    2. Define the words you do want the model to be able to generate.
    words_to_delete = ['forzado', 'vendieron', 'verticales'] # words to drop from the banned list; or load them from somewhere else
    
    3. Define a function to create the bad_words_ids, that is, the whole model vocab minus the words you want in the model.
    def create_bad_words_ids(vocab, words_to_delete):
      # Ban every token in the vocab except the words we want to keep.
      # generate() expects a list of token-id lists, one list per banned sequence.
      # Note: this also bans special tokens such as EOS unless they are in words_to_delete.
      keep = set(words_to_delete)
      return [[token_id] for token, token_id in vocab.items() if token not in keep]
    
    bad_words_ids = create_bad_words_ids(vocab=vocab, words_to_delete=words_to_delete)
    print(len(bad_words_ids), bad_words_ids[:10]) # the full list is almost as long as the vocab
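
As an alternative to banning almost the whole vocab, newer versions of transformers also accept a prefix_allowed_tokens_fn argument in generate(), which takes a callback returning the token ids allowed at each step. Below is a minimal sketch of such a callback. The helper name make_allowed_tokens_fn, the always_allowed_ids parameter and the toy vocab are my own assumptions for illustration, not part of the library. It only handles words that exist as single tokens in the vocab.

```python
from typing import Callable, Dict, Iterable, List


def make_allowed_tokens_fn(
    vocab: Dict[str, int],
    allowed_words: Iterable[str],
    always_allowed_ids: Iterable[int] = (),
) -> Callable[[int, object], List[int]]:
    """Build a callback that returns the same whitelist of token ids at every step."""
    # Keep only whitelist words that exist as single tokens in the vocab
    allowed_ids = {vocab[word] for word in allowed_words if word in vocab}
    # Special tokens (e.g. EOS) usually have to stay allowed so generation can stop
    allowed_ids.update(always_allowed_ids)
    allowed = sorted(allowed_ids)

    def allowed_tokens_fn(batch_id: int, input_ids: object) -> List[int]:
        # Same whitelist regardless of what has been generated so far
        return allowed

    return allowed_tokens_fn


# Toy vocabulary standing in for tokenizer.get_vocab() (assumption for illustration)
toy_vocab = {"[EOS]": 0, "forzado": 1, "vendieron": 2, "verticales": 3, "otro": 4}
allowed_tokens_fn = make_allowed_tokens_fn(
    toy_vocab,
    ["forzado", "vendieron", "missing"],
    always_allowed_ids=[toy_vocab["[EOS]"]],
)
print(allowed_tokens_fn(0, None))  # [0, 1, 2]

# With a real model this would be passed as (not run here):
# model.generate(input_ids, prefix_allowed_tokens_fn=allowed_tokens_fn)
```

Words that the tokenizer splits into several subword tokens would need extra handling, since the callback above only looks up whole entries in the vocab.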
    
    

    Hope it helps, cheers