huggingface-transformers huggingface-tokenizers

How to know which words are encoded with unknown tokens in HuggingFace BertTokenizer?


I use the following code to compute what percentage of words are encoded as unknown tokens.

from transformers import BertTokenizer

paragraph_chinese = '...' # It is a long paragraph from a text file.
tokenizer_bart = BertTokenizer.from_pretrained("fnlp/bart-base-chinese")
encoded_chinese_bart = tokenizer_bart.encode(paragraph_chinese)
unk_token_id_bart = tokenizer_bart.convert_tokens_to_ids(["[UNK]"])
len_paragraph_chinese = len(paragraph_chinese)

unk_token_cnt_chinese_bart = encoded_chinese_bart.count(unk_token_id_bart[0])
print("BART Unknown Token count in Chinese Paragraph:", unk_token_cnt_chinese_bart, "(" + str(unk_token_cnt_chinese_bart * 100 / len_paragraph_chinese) + "%)")
print(type(tokenizer_bart))

which prints:

BART Unknown Token count in Chinese Paragraph: 1 (0.015938795027095953%)
<class 'transformers.models.bert.tokenization_bert.BertTokenizer'>

My question is: the output shows one unknown token. How can I find out which word produced it?

P.S. I tried print(encoded_chinese_bart), but it only prints a list of token IDs.

Using transformers 4.28.1


Solution

  • Use BertTokenizerFast instead of the "slow" BertTokenizer. Calling a fast tokenizer returns a BatchEncoding object whose alignment methods can map a token back to its position in the original string; these methods are only available with fast tokenizers.

    The following code uses the token_to_chars method, which returns the character span of a given token:

    from transformers import BertTokenizerFast
    
    # just an example
    paragraph_chinese = '马云 Kočka 祖籍浙江省嵊县 Kočka 现嵊州市' 
    
    tokenizer_bart = BertTokenizerFast.from_pretrained("fnlp/bart-base-chinese")
    encoded_chinese_bart = tokenizer_bart(paragraph_chinese)
    unk_token_id_bart = tokenizer_bart.unk_token_id
    len_paragraph_chinese   = len(paragraph_chinese)
    
    unk_token_cnt_chinese_bart   = encoded_chinese_bart.input_ids.count(unk_token_id_bart)
    print(f'BART Unknown Token count in Chinese Paragraph: {unk_token_cnt_chinese_bart} ({unk_token_cnt_chinese_bart * 100 / len_paragraph_chinese}%)')
    
    # Find the indices of all unknown tokens
    unk_indices = [i for i, x in enumerate(encoded_chinese_bart.input_ids) if x == unk_token_id_bart]
    for unk_i in unk_indices:
        # token_to_chars returns the (start, end) character span of the token
        start, stop = encoded_chinese_bart.token_to_chars(unk_i)
        print(f"At {start}:{stop}: {paragraph_chinese[start:stop]}")
    

    Output:

    BART Unknown Token count in Chinese Paragraph: 2 (7.407407407407407%)
    At 3:8: Kočka
    At 17:22: Kočka
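
For a long paragraph, it may be handier to collect the unknown spans in one pass rather than call token_to_chars per token. The sketch below is one possible way to do that, not part of the original answer. It assumes you already have the token IDs and the character offsets, which a fast tokenizer can return when called with return_offsets_mapping=True (as enc["input_ids"] and enc["offset_mapping"]). The helper name unknown_spans and the sample IDs, offsets, and unk_id value are made up for illustration, so the snippet runs without downloading a model.

```python
def unknown_spans(text, input_ids, offsets, unk_id):
    """Return (start, end, substring) for every token whose ID equals unk_id.

    input_ids and offsets are parallel lists; each offset is a (start, end)
    character span into text. Empty spans, which fast tokenizers use for
    special tokens such as [CLS] and [SEP], are skipped.
    """
    spans = []
    for token_id, (start, end) in zip(input_ids, offsets):
        if token_id == unk_id and end > start:
            spans.append((start, end, text[start:end]))
    return spans


# Hand-written illustrative data (NOT real tokenizer output).
# With a real fast tokenizer you would obtain these via:
#   enc = tokenizer(text, return_offsets_mapping=True)
#   input_ids, offsets = enc["input_ids"], enc["offset_mapping"]
#   unk_id = tokenizer.unk_token_id
text = "马云 Kočka 祖籍"
input_ids = [101, 11, 12, 100, 13, 14, 102]  # hypothetical IDs
offsets = [(0, 0), (0, 1), (1, 2), (3, 8), (9, 10), (10, 11), (0, 0)]
unk_id = 100  # hypothetical [UNK] ID

print(unknown_spans(text, input_ids, offsets, unk_id))
# prints [(3, 8, 'Kočka')]
```

Passing the result to collections.Counter (counting only the substrings) would then show which distinct words are unknown and how often each occurs.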