tokenize, bert-language-model, huggingface-transformers, huggingface-tokenizers

Is there a way to get the location of the substring from which a certain token has been produced in BERT?


I am feeding sentences to a BERT model (Hugging Face library). These sentences get tokenized with a pretrained tokenizer. I know that you can use the decode function to go back from tokens to strings.

string = tokenizer.decode(...)

However, the reconstruction is not perfect. If you use an uncased pretrained model, the uppercase letters get lost. Also, if the tokenizer splits a word into multiple tokens, every token after the first starts with '##'. For example, the word 'coronavirus' gets split into two tokens: 'corona' and '##virus'.
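For instance (a small sketch, assuming the bert-base-uncased tokenizer), the split and the loss of casing look like this:

from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# the word is split into a word piece and a '##' continuation piece
print(tokenizer.tokenize('coronavirus'))   # ['corona', '##virus']

# decoding the ids reconstructs the word, but the original casing is gone
ids = tokenizer.encode('Coronavirus', add_special_tokens=False)
print(tokenizer.decode(ids))               # 'coronavirus'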

So my question is: is there a way to get the indices of the substring from which every token is created? For example, take the string "Tokyo to report nearly 370 new coronavirus cases, setting new single-day record". The 9th token is the token corresponding to 'virus'.

['[CLS]', 'tokyo', 'to', 'report', 'nearly', '370', 'new', 'corona', '##virus', 'cases', ',', 'setting', 'new', 'single', '-', 'day', 'record', '[SEP]']

I want something that tells me that the token '##virus' comes from the 'virus' substring in the original string, which is located between the indices 37 and 41 of the original string.

sentence = "Tokyo to report nearly 370 new coronavirus cases, setting new single-day record"
print(sentence[37:42]) # --> outputs 'virus'

Solution

  • As far as I know there is no built-in method for that, but you can create one yourself:

    import re
    from transformers import BertTokenizer
    
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    
    sentence = "Tokyo to report nearly 370 new coronavirus cases, setting new single-day record"
    
    b = []
    b.append(([101],))  # id of the [CLS] token; it has no character span
    for m in re.finditer(r'\S+', sentence):
        w = m.group(0)  # the whitespace-delimited word
        # pair the word's token ids with its (start, end) character indices, end inclusive
        t = (tokenizer.encode(w, add_special_tokens=False), (m.start(), m.end() - 1))
        b.append(t)
    
    b.append(([102],))  # id of the [SEP] token; it has no character span
    
    b
    

    Output:

    [([101],),
     ([5522], (0, 4)),
     ([2000], (6, 7)),
     ([3189], (9, 14)),
     ([3053], (16, 21)),
     ([16444], (23, 25)),
     ([2047], (27, 29)),
     ([21887, 23350], (31, 41)),
     ([3572, 1010], (43, 48)),
     ([4292], (50, 56)),
     ([2047], (58, 60)),
     ([2309, 1011, 2154], (62, 71)),
     ([2501], (73, 78)),
     ([102],)]
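
    To read a token's span back from this structure, a small loop like the one below (a sketch based on the tuples produced above) prints every token next to the character span of the word it came from. Note the spans are word-level, so '##virus' maps to the span of 'coronavirus':

    for ids, *span in b:
        if not span:  # the special tokens [CLS] and [SEP] carry no span
            continue
        start, end = span[0]
        for token in tokenizer.convert_ids_to_tokens(ids):
            print(token, (start, end), sentence[start:end + 1])

    For the example sentence this prints, among other lines, "##virus (31, 41) coronavirus".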