huggingface-transformers

How do I prevent a token from being split by the Hugging Face tokenizer?


I have a string such as "xxx yyy zzz" and I am using the BERT tokenizer from Huggingface:

from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
mylistoftoken = tokenizer.tokenize("xxx yyy zzz")

However, I want to enforce that certain words (for example "abcd") are never split into subword pieces (such as "abc" and "##d").

Is there a way for me to enforce that without post-processing the array of tokens and putting them back together?


Solution

  • There might be better solutions depending on your use case, but based on the information you provided, you are looking for add_tokens:

    from transformers import BertTokenizer

    t = BertTokenizer.from_pretrained("bert-base-uncased")
    print(t.tokenize("xxx yyy zzz abcd"))
    # Tokens added to the vocabulary are matched as whole words and never split.
    t.add_tokens(["yyy", "abcd"])
    print(t.tokenize("xxx yyy zzz abcd"))
    

    Output:

    ['xx', '##x', 'y', '##y', '##y', 'z', '##zz', 'abc', '##d']
    ['xx', '##x', 'yyy', 'z', '##zz', 'abcd']
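
    If you pass the tokenizer's output to a model, the model's embedding matrix also has to grow to cover the newly added tokens. A minimal sketch, assuming a bert-base-uncased model loaded via AutoModel (not part of the original answer):

    from transformers import AutoModel, BertTokenizer

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    num_added = tokenizer.add_tokens(["yyy", "abcd"])  # returns how many tokens were actually added

    model = AutoModel.from_pretrained("bert-base-uncased")
    # The new tokens start out without embeddings, so resize the embedding
    # matrix to the new vocabulary size before running or fine-tuning the model.
    model.resize_token_embeddings(len(tokenizer))

    The new embedding rows are randomly initialized, so the added tokens will only carry useful representations after fine-tuning.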