
Assigning weights when testing a BERT model


I have a basic conceptual doubt. Say I train a BERT model on a sentence such as:

Train: "went to get loan from bank"
Test:  "received education loan from bank"

How does the test sentence get weights assigned for each token? I don't pass the exact training sentence at test time, and there is a slight addition of words like "education" which changes the context slightly.

Assuming such a context was not seen during training, how are the weights assigned for each token in my BERT model before I fine-tune it further?

If my question is confusing, simply put: I am trying to understand how the weights get assigned during testing when a slight variation in context occurs that was not trained on.


Solution

  • The vector representation of a token (keep in mind that token != word) is stored in an embedding layer. When we load the 'bert-base-uncased' model, we can see that it "knows" 30522 tokens and that the vector representation of each token consists of 768 elements:

    from transformers import BertModel

    bert = BertModel.from_pretrained('bert-base-uncased')
    print(bert.embeddings.word_embeddings)
    

    Output:

    Embedding(30522, 768, padding_idx=0)
    

    This embedding layer does not know about strings at all, only about ids. For example, the vector representation of the id 101 is:

    print(bert.embeddings.word_embeddings.weight[101])
    

    Output:

    tensor([ 1.3630e-02, -2.6490e-02, -2.3503e-02, -7.7876e-03,  8.5892e-03,
            -7.6645e-03, -9.8808e-03,  6.0184e-03,  4.6921e-03, -3.0984e-02,
             1.8883e-02, -6.0093e-03, -1.6652e-02,  1.1684e-02, -3.6245e-02,
             8.3482e-03, -1.2112e-03,  1.0322e-02,  1.6692e-02, -3.0354e-02,
            ...
             5.4162e-03, -3.0037e-02,  8.6773e-03, -1.7942e-03,  6.6826e-03,
            -1.1929e-02, -1.4076e-02,  1.6709e-02,  1.6860e-03, -3.3842e-03,
             8.6805e-03,  7.1340e-03,  1.5147e-02], grad_fn=<SelectBackward>)
    
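    To emphasize that only ids matter at this layer, here is a minimal sketch (reusing the bert object loaded above; the ids 100, 101 and 102 are the special [UNK], [CLS] and [SEP] ids of 'bert-base-uncased'):

    import torch

    # the embedding layer maps a tensor of ids to a tensor of 768-dimensional vectors
    ids = torch.tensor([[101, 100, 102]])            # [CLS], [UNK], [SEP]
    vectors = bert.embeddings.word_embeddings(ids)
    print(vectors.shape)                             # torch.Size([1, 3, 768])
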

    Everything that is outside of these "known" ids cannot be processed by BERT. To answer your question, we need to look at the component that maps a string to ids. This component is called a tokenizer. There are different tokenization approaches; BERT uses a WordPiece tokenizer, which is a subword algorithm. This algorithm replaces everything that cannot be built from its vocabulary with an unknown token that is itself part of the vocabulary ([UNK] in the original implementation, id: 100).

    Please have a look at the following small example, in which a WordPiece tokenizer is trained from scratch, to confirm that behaviour:

    from tokenizers import BertWordPieceTokenizer

    # train a tiny WordPiece vocabulary from scratch on the training sentence only
    path = 'file_with_your_trainings_sentence.txt'
    tokenizer = BertWordPieceTokenizer()
    tokenizer.train(files=path, vocab_size=30000, special_tokens=['[UNK]', '[SEP]', '[PAD]', '[CLS]', '[MASK]'])

    otrain = tokenizer.encode("went to get loan from bank")
    otest = tokenizer.encode("received education loan from bank")

    print('Vocabulary size: {}'.format(tokenizer.get_vocab_size()))
    print('Train tokens: {}'.format(otrain.tokens))
    print('Test tokens: {}'.format(otest.tokens))
    

    Output:

    Vocabulary size: 27
    Train tokens: ['w', '##e', '##n', '##t', 't', '##o', 'g', '##e', '##t', 'l', '##o', '##an', 'f', '##r', '##o', '##m', 'b', '##an', '##k']
    Test tokens: ['[UNK]', '[UNK]', 'l', '##o', '##an', 'f', '##r', '##o', '##m', 'b', '##an', '##k']
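
    For contrast, here is what the pretrained tokenizer that belongs to 'bert-base-uncased' does with the same two sentences. This is only a minimal sketch, but with its full vocabulary of 30522 tokens I would expect no [UNK] at all, because every word in both sentences (including "education") can be built from that vocabulary:

    from transformers import BertTokenizer

    # the pretrained tokenizer that matches the embedding layer shown above
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    train_tokens = tokenizer.tokenize("went to get loan from bank")
    test_tokens = tokenizer.tokenize("received education loan from bank")

    print('Train tokens: {}'.format(train_tokens))
    print('Test tokens: {}'.format(test_tokens))
    print('Test ids: {}'.format(tokenizer.convert_tokens_to_ids(test_tokens)))

    Each of those ids is a row index into the embedding matrix from the beginning of this answer, which is why the test sentence still receives meaningful pretrained vectors before any fine-tuning, even though BERT never saw that exact sentence.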