Tags: python, nlp, huggingface-transformers, bert-language-model

BERT model splits words on its own


I am tokenizing input words with a BERT model. The code is:

from transformers import BertTokenizer, BertModel

tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-cased', do_lower_case=False)
model = BertModel.from_pretrained("bert-base-multilingual-cased", add_pooling_layer=False, output_hidden_states=True, output_attentions=True)

text = "stradivarius"  # example input (assumed)
marked_text = text + " [SEP]"
tokenized_text = tokenizer.tokenize(marked_text)
indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
print(tokenized_text)
print(indexed_tokens)

The model I used is from HuggingFace.

My goal is to print the embedding vectors of all the words the BERT model knows, so I searched and found that this model has 119296 tokens available.

I don't know whether the number of tokens is the reason, but the model splits words on its own, which is not what I want.

For example:


only -> [only]
ONLY -> [ON,L,Y]

stradivarius -> ['St', '##radi', '##vari', '##us']

Is this normal BERT behavior, or am I doing something wrong?


Solution

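  • You are not doing anything wrong. BERT uses a WordPiece subword tokenizer, a compromise between a character-level tokenizer (small vocabulary, but single tokens carry little meaning) and a word-level tokenizer (meaningful tokens, but a very large vocabulary and memory footprint). Words that are not in the vocabulary as a whole, such as ONLY or stradivarius, are split into smaller pieces that are; pieces that continue a word are prefixed with ##.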
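
The splitting seen in the question can be reproduced with a minimal sketch of greedy longest-match-first matching, which is how WordPiece segments a word at tokenization time. The function name `wordpiece` and the tiny `toy_vocab` are assumptions made for illustration; the real vocabulary of bert-base-multilingual-cased is far larger, and the real tokenizer also handles punctuation, casing options, and a maximum word length.

```python
def wordpiece(word, vocab, unk="[UNK]"):
    """Greedy longest-match-first split of one word into vocabulary pieces."""
    pieces, start = [], 0
    while start < len(word):
        end = len(word)
        piece = None
        # Try the longest remaining substring first, then shrink it from the right.
        while end > start:
            candidate = word[start:end]
            if start > 0:
                candidate = "##" + candidate  # marks a piece that continues a word
            if candidate in vocab:
                piece = candidate
                break
            end -= 1
        if piece is None:
            # No piece matches at this position: the whole word is unknown.
            return [unk]
        pieces.append(piece)
        start = end
    return pieces


# Hypothetical toy vocabulary, chosen to mirror the examples in the question.
toy_vocab = {"only", "ON", "##L", "##Y", "St", "##radi", "##vari", "##us"}

print(wordpiece("only", toy_vocab))          # ['only']
print(wordpiece("ONLY", toy_vocab))          # ['ON', '##L', '##Y']
print(wordpiece("Stradivarius", toy_vocab))  # ['St', '##radi', '##vari', '##us']
```

A word is kept whole only when the full string is itself a vocabulary entry, which is why "only" survives while "ONLY" does not in a cased vocabulary.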

    A common approach to retrieve word embeddings from a subword-based model is to take the mean of the embeddings of the tokens that make up each word. The code below shows how you can retrieve the word embeddings (non-contextualized and contextualized) by taking the mean. It uses a fast tokenizer (BertTokenizerFast) because the BatchEncoding object it returns provides methods that map between tokens, words, and character positions.

    import torch
    from transformers import BertTokenizerFast, BertModel
    
    t = BertTokenizerFast.from_pretrained('bert-base-multilingual-cased')
    # whole model
    m = BertModel.from_pretrained("bert-base-multilingual-cased")
    # token embedding layer
    embedding_layer = m.embeddings.word_embeddings
    
    sample_sentence = 'This is an example with token-embeddings and word-embeddings'
    encoded = t([sample_sentence])
    # The BatchEncoding object allows us to map the token back to the string indices
    print(*[(token_id, encoded.token_to_chars(idx)) for idx, token_id in enumerate(encoded.input_ids[0])], sep="\n")
    # And we can also check the mapping of word to token indices
    print(*[(word, encoded.word_to_tokens(idx)) for idx, word in enumerate(sample_sentence.split())], sep="\n")
    

    Output:

    (101, None)
    (10747, CharSpan(start=0, end=4))
    (10124, CharSpan(start=5, end=7))
    (10151, CharSpan(start=8, end=10))
    (14351, CharSpan(start=11, end=18))
    (10169, CharSpan(start=19, end=23))
    (18436, CharSpan(start=24, end=27))
    (10136, CharSpan(start=27, end=29))
    (118, CharSpan(start=29, end=30))
    (10266, CharSpan(start=30, end=32))
    (33627, CharSpan(start=32, end=35))
    (13971, CharSpan(start=35, end=39))
    (10107, CharSpan(start=39, end=40))
    (10111, CharSpan(start=41, end=44))
    (12307, CharSpan(start=45, end=49))
    (118, CharSpan(start=49, end=50))
    (10266, CharSpan(start=50, end=52))
    (33627, CharSpan(start=52, end=55))
    (13971, CharSpan(start=55, end=59))
    (10107, CharSpan(start=59, end=60))
    (102, None)
    ('This', TokenSpan(start=1, end=2))
    ('is', TokenSpan(start=2, end=3))
    ('an', TokenSpan(start=3, end=4))
    ('example', TokenSpan(start=4, end=5))
    ('with', TokenSpan(start=5, end=6))
    ('token-embeddings', TokenSpan(start=6, end=8))
    ('and', TokenSpan(start=8, end=9))
    ('word-embeddings', TokenSpan(start=9, end=13))
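
In the output above, 'token-embeddings' is mapped to TokenSpan(start=6, end=8), which covers only the two pieces of "token", and 'and' is mapped to the hyphen at index 8. This happens because word_to_tokens counts the tokenizer's own words, which are split on punctuation as well as whitespace, so the indices drift away from sample_sentence.split() after the first hyphen. One possible way to group tokens by whitespace-delimited word instead is to compare character spans, sketched below in plain Python. The helper names `whitespace_word_spans` and `align_tokens_to_words` are hypothetical, and the span list is copied by hand from the token_to_chars output above.

```python
import re


def whitespace_word_spans(sentence):
    """Return (word, char_start, char_end) for each whitespace-delimited word."""
    return [(m.group(), m.start(), m.end()) for m in re.finditer(r"\S+", sentence)]


def align_tokens_to_words(sentence, token_char_spans):
    """Map each whitespace word to a half-open token range (start, end).

    token_char_spans[i] is (char_start, char_end) for token i, or None for
    special tokens such as [CLS] and [SEP].
    """
    ranges = []
    for word, w_start, w_end in whitespace_word_spans(sentence):
        inside = [
            i for i, span in enumerate(token_char_spans)
            if span is not None and span[0] >= w_start and span[1] <= w_end
        ]
        if inside:  # skip words that received no tokens
            ranges.append((word, (inside[0], inside[-1] + 1)))
    return ranges


sentence = "This is an example with token-embeddings and word-embeddings"
# Character spans copied from the output above (None marks [CLS] and [SEP]).
spans = [
    None, (0, 4), (5, 7), (8, 10), (11, 18), (19, 23),
    (24, 27), (27, 29), (29, 30), (30, 32), (32, 35), (35, 39), (39, 40),
    (41, 44),
    (45, 49), (49, 50), (50, 52), (52, 55), (55, 59), (59, 60),
    None,
]

for word, token_range in align_tokens_to_words(sentence, spans):
    print(word, token_range)
# The hyphenated words now cover all of their pieces:
# token-embeddings (6, 13), and (13, 14), word-embeddings (14, 20)
```

With a live BatchEncoding, the span list should be obtainable as `[encoded.token_to_chars(i) for i in range(len(encoded.input_ids[0]))]`, since each CharSpan unpacks like a (start, end) tuple and special tokens yield None, as the output above shows.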
    

    To retrieve the word embeddings:

    with torch.inference_mode():
      token_embeddings = embedding_layer(encoded.convert_to_tensors("pt").input_ids).squeeze()
      # we need the attention mechanism of the whole model to get the contextualized token representations
      contextualized_token_embeddings = m(**encoded.convert_to_tensors("pt")).last_hidden_state.squeeze()
    
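    def fetch_word_embeddings(sample_sentence: str, encoded, embeddings: torch.Tensor) -> dict[str, torch.Tensor]:
      # Average the token vectors of each word. Caveat: word_to_tokens(idx) indexes the
      # tokenizer's own words (split on whitespace and punctuation), so idx matches
      # sample_sentence.split() only for sentences without punctuation.
      word_embeddings = {}
      for idx, word in enumerate(sample_sentence.split()):
        start, end = encoded.word_to_tokens(idx)
        word_embeddings[word] = embeddings[start:end].mean(dim=0)
      return word_embeddings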
    
    word_embeddings = fetch_word_embeddings(sample_sentence, encoded, token_embeddings)
    contextualized_word_embeddings = fetch_word_embeddings(sample_sentence, encoded, contextualized_token_embeddings)
    print(word_embeddings["token-embeddings"])
    print(contextualized_word_embeddings["token-embeddings"])
    

    Output:

    tensor([ 1.2455e-02, -3.8478e-02,  8.0834e-03, ..., -1.8502e-02,  1.1511e-02, -6.5307e-02])
    tensor([-5.1564e-01, -1.6266e-01, -3.9420e-01, ..., -5.9969e-02,  3.0784e-01, -3.4451e-01])
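
The pooling step inside fetch_word_embeddings can also be seen in isolation. The sketch below averages toy two-dimensional vectors in plain Python; the names `mean_pool` and `pool_words` and all the numbers are made up for illustration and stand in for the torch tensors used above.

```python
def mean_pool(vectors):
    """Element-wise mean of a non-empty list of equal-length vectors."""
    n = len(vectors)
    return [sum(column) / n for column in zip(*vectors)]


def pool_words(token_vectors, word_token_ranges):
    """Average the token vectors inside each word's half-open range (start, end)."""
    return {
        word: mean_pool(token_vectors[start:end])
        for word, (start, end) in word_token_ranges.items()
    }


# Toy vectors for the tokens [CLS], 'St', '##radi', '##vari', '##us', [SEP].
token_vectors = [
    [0.0, 0.0],
    [1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0],
    [9.0, 9.0],
]
# The word occupies tokens 1..4; the special tokens are left out of the mean.
word_token_ranges = {"Stradivarius": (1, 5)}

print(pool_words(token_vectors, word_token_ranges))  # {'Stradivarius': [4.0, 5.0]}
```

Because the result is a dictionary keyed by the word string, a word that occurs twice in a sentence keeps only its last vector. That matters mostly for the contextualized embeddings, where the two occurrences generally differ; keying by word position instead would avoid it.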