
How does BertForSequenceClassification classify on the CLS vector?


Background:

Following along with this question, when using BERT to classify sequences, the model uses the [CLS] token to represent the classification task. According to the paper:

The first token of every sequence is always a special classification token ([CLS]). The final hidden state corresponding to this token is used as the aggregate sequence representation for classification tasks.

Looking at the Hugging Face repo, their BertForSequenceClassification uses the BertPooler:

from torch import nn

class BertPooler(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output

We can see that they take the first token ([CLS]) and use it as a representation of the whole sentence. Specifically, they perform hidden_states[:, 0], which looks a lot like it's taking the first element of each state rather than the first token's hidden state?
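For reference, the hidden_states passed to BertPooler (BertModel's last hidden states) have shape (batch_size, sequence_length, hidden_size). The first index runs over the sequences in the batch, so hidden_states[:, 0] picks position 0, the [CLS] token, from every sequence. A minimal sketch with a random tensor standing in for real BERT output (the sizes are arbitrary assumptions):

```python
import torch

# Hypothetical sizes for illustration: 2 sequences, 5 tokens each, hidden size 4.
batch_size, seq_len, hidden_size = 2, 5, 4

# BertModel's last hidden states have shape (batch_size, seq_len, hidden_size).
hidden_states = torch.randn(batch_size, seq_len, hidden_size)

# [:, 0] keeps every sequence in the batch (dim 0) and selects position 0 (dim 1),
# i.e. the [CLS] token's hidden state for each sequence.
first_token_tensor = hidden_states[:, 0]

print(first_token_tensor.shape)  # torch.Size([2, 4])
# Same as indexing position 0 of a single sequence explicitly:
print(torch.equal(first_token_tensor[1], hidden_states[1, 0]))  # True
```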

My Question:

What I don't understand is how they encode the information from the entire sentence into this token. Is the CLS token a regular token which has its own embedding vector that "learns" the sentence-level representation? Why can't we just use the average of the hidden states (the output of the encoder) to classify?

EDIT: After thinking about it a little: because we use the CLS token's hidden state to predict, is the CLS token's embedding being trained on the classification task, since this is the token being used to classify (and is thus the major contributor to the error that gets propagated to its weights)?


Solution

  • Is the CLS token a regular token which has its own embedding vector that "learns" the sentence-level representation?

    Yes:

    import torch
    from transformers import BertTokenizer, BertModel
    
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    model = BertModel.from_pretrained('bert-base-uncased')
    
    # id of the [CLS] token in the vocabulary
    clsToken = tokenizer.convert_tokens_to_ids('[CLS]')
    print(clsToken)
    # or
    print(tokenizer.cls_token, tokenizer.cls_token_id)
    
    # [CLS] has its own row in the input (word) embedding matrix
    print(model.get_input_embeddings()(torch.tensor(clsToken)))
    

    Output:

    101
    [CLS] 101
    tensor([ 1.3630e-02, -2.6490e-02, -2.3503e-02, -7.7876e-03,  8.5892e-03,
            -7.6645e-03, -9.8808e-03,  6.0184e-03,  4.6921e-03, -3.0984e-02,
             1.8883e-02, -6.0093e-03, -1.6652e-02,  1.1684e-02, -3.6245e-02,
             ...
             5.4162e-03, -3.0037e-02,  8.6773e-03, -1.7942e-03,  6.6826e-03,
            -1.1929e-02, -1.4076e-02,  1.6709e-02,  1.6860e-03, -3.3842e-03,
             8.6805e-03,  7.1340e-03,  1.5147e-02], grad_fn=<EmbeddingBackward>)
    

    You can get a list of all the special tokens for your model with:

    print(tokenizer.all_special_tokens)
    

    Output:

    ['[CLS]', '[UNK]', '[PAD]', '[SEP]', '[MASK]']
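To make the "regular token" point concrete: looking up an id in an embedding table returns one row of its weight matrix, and backpropagation updates that row like any other parameter. A toy sketch with plain torch.nn.Embedding, not the actual BERT model; the table size and the ids 7 and 42 are hypothetical, and only the [CLS] id 101 comes from the output above:

```python
import torch
from torch import nn

# Toy stand-in for BERT's word-embedding table (hypothetical sizes;
# bert-base-uncased's real table is far larger).
vocab_size, hidden_size = 200, 8
embeddings = nn.Embedding(vocab_size, hidden_size)

cls_id = 101  # the [CLS] id printed above for bert-base-uncased
# Hypothetical input: [CLS] followed by two arbitrary token ids.
input_ids = torch.tensor([[cls_id, 7, 42]])

# Looking up an id returns exactly that row of the weight matrix.
vectors = embeddings(input_ids)  # shape (1, 3, hidden_size)
print(torch.equal(vectors[0, 0], embeddings.weight[cls_id]))  # True

# A loss computed from the [CLS] vector alone (no encoder mixes positions here)
# produces gradient for the [CLS] row and for no other row.
loss = vectors[0, 0].sum()
loss.backward()

grad = embeddings.weight.grad
print(grad[cls_id].abs().sum().item() > 0)  # True
print(grad[7].abs().sum().item() == 0)      # True
```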
    

    What I don't understand is how they encode the information from the entire sentence into this token.

    and

    Because we use the CLS token's hidden state to predict, is the CLS token's embedding being trained on the classification task, since this is the token being used to classify (and is thus the major contributor to the error that gets propagated to its weights)?

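    Also yes. As you already stated in your question, BertForSequenceClassification uses the BertPooler output to train the linear classification layer on top of BERT. The information from the whole sentence reaches the [CLS] position through self-attention: in every encoder layer, the hidden state at position 0 is computed from the hidden states of all tokens in the sequence. The classification loss is computed from this position only, so during fine-tuning its gradient flows back through the classifier and the pooler into the encoder weights and the input embeddings (including the [CLS] embedding), training them to put task-relevant information there: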

    # Excerpt from BertForSequenceClassification.forward:
    # outputs is the output of BertModel; its second element is the pooler output
    pooled_output = outputs[1]
    
    pooled_output = self.dropout(pooled_output)
    logits = self.classifier(pooled_output)
    
    # ... loss calculation based on logits and the given labels
    
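How information from the whole sentence reaches the [CLS] position can be checked on a toy model: self-attention lets position 0 read from every token, so a loss computed from position 0 alone still sends gradient to every input token's embedding. The sketch below is not BERT. It assumes a single torch.nn.TransformerEncoderLayer as a stand-in for BERT's encoder stack, plus a tanh pooler and a linear classifier that mirror the BertPooler and classifier structure; sizes and token ids are hypothetical, and batch_first needs PyTorch 1.9 or later:

```python
import torch
from torch import nn

torch.manual_seed(0)

# Toy stand-ins (hypothetical sizes): token embeddings, one encoder layer
# (self-attention + feed-forward), a tanh pooler on position 0, and a classifier.
vocab_size, hidden_size, num_labels = 200, 16, 2
embeddings = nn.Embedding(vocab_size, hidden_size)
encoder = nn.TransformerEncoderLayer(
    d_model=hidden_size, nhead=2, dim_feedforward=32, dropout=0.0, batch_first=True
)
pooler = nn.Sequential(nn.Linear(hidden_size, hidden_size), nn.Tanh())
classifier = nn.Linear(hidden_size, num_labels)

# Hypothetical input: [CLS] (id 101) followed by three arbitrary token ids.
input_ids = torch.tensor([[101, 7, 42, 99]])
labels = torch.tensor([1])

hidden_states = encoder(embeddings(input_ids))  # (1, 4, hidden_size)
pooled_output = pooler(hidden_states[:, 0])      # only the [CLS] position
logits = classifier(pooled_output)
loss = nn.functional.cross_entropy(logits, labels)
loss.backward()

# Although the loss only looks at position 0, attention mixed all positions into
# it, so every token that appears in the input gets a gradient on its embedding row.
grad = embeddings.weight.grad
for token_id in input_ids[0].tolist():
    print(token_id, grad[token_id].norm().item() > 0)
```

Every id in input_ids should report True, while rows of the table that never appear in the input receive no gradient.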

    Why can't we just use the average of the hidden states (the output of the encoder) to classify?

    I can't really answer this in general, but why do you think this would be easier or better than a linear layer on the pooled [CLS] output? You would still need to train the hidden layers to produce outputs whose average maps to your class, so the "average layer" would then have to be the major contributor to your loss. In general, if you can show that it leads to better results than the current approach, nobody will reject it.
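For comparison, averaging itself is easy to implement; what it does not remove is the need for training. A minimal sketch of masked mean pooling in plain PyTorch, assuming hidden states shaped like BertModel's last_hidden_state and a 0/1 attention_mask like the one the tokenizer returns. mean_pool is a hypothetical helper, not a Hugging Face API, and the numbers are made up:

```python
import torch
from torch import nn


def mean_pool(hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Average the hidden states of real tokens, ignoring padding.

    hidden_states: (batch_size, seq_len, hidden_size)
    attention_mask: (batch_size, seq_len), 1 for real tokens and 0 for padding
    """
    mask = attention_mask.unsqueeze(-1).to(hidden_states.dtype)  # (batch, seq_len, 1)
    summed = (hidden_states * mask).sum(dim=1)                   # (batch, hidden)
    counts = mask.sum(dim=1).clamp(min=1e-9)                     # (batch, 1)
    return summed / counts


# Toy input: two sequences of length 3; the second has one padding position.
hidden_states = torch.tensor([
    [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
    [[2.0, 2.0], [4.0, 4.0], [100.0, 100.0]],  # last position is padding
])
attention_mask = torch.tensor([[1, 1, 1], [1, 1, 0]])

pooled = mean_pool(hidden_states, attention_mask)
print(pooled.tolist())  # -> [[3.0, 4.0], [3.0, 3.0]]

# A linear classifier on the averaged vector (hypothetical: 2 labels). Like the
# [CLS] head, it only becomes useful after training, because the encoder has to
# learn hidden states whose average separates the classes.
classifier = nn.Linear(2, 2)
logits = classifier(pooled)
print(logits.shape)  # torch.Size([2, 2])
```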