I'm using the Hugging Face Transformers package and BERT with PyTorch. I'm trying to do 4-way sentiment classification and am using BertForSequenceClassification to build a model that ends in a 4-way softmax.
My understanding from reading the BERT paper is that the final dense vector for the input CLS token serves as a representation of the whole text string:
The first token of every sequence is always a special classification token ([CLS]). The final hidden state corresponding to this token is used as the aggregate sequence representation for classification tasks.
So, does BertForSequenceClassification actually train and use this vector to perform the final classification?

The reason I ask is that when I print(model), it is not obvious to me that the CLS vector is being used.
model = BertForSequenceClassification.from_pretrained(
    model_config,
    num_labels=num_labels,
    output_attentions=False,
    output_hidden_states=False
)
print(model)
Here is the bottom of the output:
        (11): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
          )
          (intermediate): BertIntermediate(
            (dense): Linear(in_features=768, out_features=3072, bias=True)
          )
          (output): BertOutput(
            (dense): Linear(in_features=3072, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
      )
    )
    (pooler): BertPooler(
      (dense): Linear(in_features=768, out_features=768, bias=True)
      (activation): Tanh()
    )
  )
  (dropout): Dropout(p=0.1, inplace=False)
  (classifier): Linear(in_features=768, out_features=4, bias=True)
)
I see that there is a pooling layer BertPooler leading to a Dropout leading to a Linear, which presumably performs the final 4-way softmax. However, the use of the BertPooler is not clear to me. Is it operating on only the hidden state of CLS, or is it doing some kind of pooling over the hidden states of all the input tokens?
Thanks for any help.
The short answer: Yes, you are correct. Indeed, they use the CLS token (and only that) for BertForSequenceClassification.
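If you want to verify this end-to-end, the following sketch recomputes the logits by hand from the pooled output. The checkpoint name and example sentence are just placeholders for illustration, and the exact return format of the models has changed a bit across transformers versions:

import torch
from transformers import BertTokenizer, BertForSequenceClassification

# "bert-base-uncased" and the sentence below are placeholders; with num_labels=4
# the classification head is freshly initialized (untrained), which is fine here.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=4)
model.eval()

inputs = tokenizer("The movie was surprisingly good.", return_tensors="pt")

with torch.no_grad():
    # Run only the encoder: index 1 of the output is the pooled output from
    # BertPooler, which is computed from the [CLS] hidden state alone.
    sequence_output, pooled_output = model.bert(**inputs)[:2]

    # Reproduce the classification head by hand (dropout is a no-op in eval mode).
    manual_logits = model.classifier(model.dropout(pooled_output))

    # The model's own forward pass produces the same logits.
    logits = model(**inputs)[0]

print(torch.allclose(manual_logits, logits))  # True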
Looking at the implementation of the BertPooler reveals that it is using the first hidden state, which corresponds to the [CLS] token.
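For reference, its forward pass looks roughly like this (paraphrased from modeling_bert.py; exact details may differ between transformers versions):

import torch.nn as nn

class BertPooler(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        # "Pool" the sequence by taking the hidden state of the first token
        # only, i.e. [CLS]; there is no averaging or max-pooling over the
        # other tokens.
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output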
I briefly checked one other model (RoBERTa) to see whether this is consistent across models. There, too, classification only takes place based on the [CLS] token, albeit less obviously (check lines 539-542 here).