pytorch, bert-language-model, huggingface-transformers, logits

How does the BERT model select the label ordering?


I'm training BertForSequenceClassification for a classification task. My dataset consists of 'contains adverse effect' (1) and 'does not contain adverse effect' (0) examples. The dataset contains all of the 1s first and then the 0s (the data isn't shuffled). For training I've shuffled my data, and I get the logits out of the model. From what I've understood, the logits are the raw values before the softmax is applied. An example logit is [-4.673831, 4.7095485]. Does the first value correspond to label 1 (contains AE) because it appears first in the dataset, or to label 0? Any help would be appreciated, thanks.


Solution

  • The first value corresponds to label 0 and the second value corresponds to label 1. What BertForSequenceClassification does is feed the output of the pooler to a linear layer (after a dropout, which I will ignore in this answer). Let's look at the following example:

    from torch import nn
    from transformers import BertModel, BertTokenizer

    t = BertTokenizer.from_pretrained('bert-base-uncased')
    m = BertModel.from_pretrained('bert-base-uncased')

    # tokenize one example sentence and run it through the bare BertModel
    i = t.encode_plus('This is an example.', return_tensors='pt')
    o = m(**i)
    print(o.pooler_output.shape)
    

    Output:

    torch.Size([1, 768])
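
    In recent versions of transformers the pooler is itself just another linear layer followed by a tanh, applied to the hidden state of the first ([CLS]) token. If you are curious, you can verify this on the model above; a small sketch, assuming the model output also exposes last_hidden_state:

    import torch

    # recompute pooler_output by hand from the [CLS] hidden state
    cls_hidden = o.last_hidden_state[:, 0]  # shape [1, 768]
    print(torch.allclose(torch.tanh(m.pooler.dense(cls_hidden)), o.pooler_output))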
    

    The pooler_output is a tensor of shape [batch_size, hidden_size] and represents the contextualized (i.e. attention was applied) [CLS] token of your input sequence. This tensor is fed to a linear layer to calculate the logits for your sequence:

    # a randomly initialized classification head: 768 features -> 2 labels
    classificationLayer = nn.Linear(768, 2)
    logits = classificationLayer(o.pooler_output)
    

    When we normalize these logits we can see that the linear layer predicts that our input should belong to label 1:

    print(nn.functional.softmax(logits,dim=-1))
    

    Output (will differ since the linear layer is initialized randomly):

    tensor([[0.1679, 0.8321]], grad_fn=<SoftmaxBackward>)
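
    'Normalizing' here just means applying the softmax, which you can also compute by hand by exponentiating the logits and dividing by their sum:

    import torch

    # softmax by hand; reproduces the tensor above
    manual = torch.exp(logits) / torch.exp(logits).sum(dim=-1, keepdim=True)
    print(manual)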
    

    The linear layer applies a linear transformation y = xA^T + b, and you can already see that the linear layer is not aware of your labels. It 'only' has a weight matrix of size [2, 768] that produces logits of size [1, 2] (i.e. the first row of the weight matrix produces the first logit and the second row produces the second):

    import torch

    # recompute the logits by hand: y = x A^T + b
    logitsOwnCalculation = torch.matmul(o.pooler_output, classificationLayer.weight.transpose(0, 1)) + classificationLayer.bias
    print(nn.functional.softmax(logitsOwnCalculation, dim=-1))
    

    Output:

    tensor([[0.1679, 0.8321]], grad_fn=<SoftmaxBackward>)
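
    To convince yourself that the layer itself has no built-in notion of which position means which label, you can build a copy of it with the two output rows swapped and watch the two probabilities simply trade places; a small sketch reusing classificationLayer from above:

    # copy of the classification layer with its two output rows (and biases) swapped
    swapped = nn.Linear(768, 2)
    with torch.no_grad():
        swapped.weight.copy_(classificationLayer.weight[[1, 0]])
        swapped.bias.copy_(classificationLayer.bias[[1, 0]])

    print(nn.functional.softmax(swapped(o.pooler_output), dim=-1))
    # the two probabilities appear in swapped order compared to the output above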
    

    The BertForSequenceClassification model learns by applying a CrossEntropyLoss. This loss function is small when the logit of the correct class (label in your case) is large compared to the other one, and large otherwise. That means it is the CrossEntropyLoss that lets your model learn that the first logit should be high when the input does not contain an adverse effect and the second logit should be high when it does; the mapping is fixed by the integer labels you pass during training, not by the order of examples in your dataset. You can check this for our example with the following:

    loss_fct = nn.CrossEntropyLoss()
    label0 = torch.tensor([0]) #does not contain adverse effect
    label1 = torch.tensor([1]) #contains adverse effect
    print(loss_fct(logits, label0))
    print(loss_fct(logits, label1))
    

    Output:

    tensor(1.7845, grad_fn=<NllLossBackward>)
    tensor(0.1838, grad_fn=<NllLossBackward>)
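
    As a final check: CrossEntropyLoss is just a log-softmax followed by a negative log likelihood, so the two values above are simply minus the log of the probability assigned to each label. And for the example logit from your question the prediction is the argmax, i.e. index 1, which means 'contains adverse effect' as long as you trained with 0 = no adverse effect and 1 = adverse effect; a small sketch reusing logits and loss_fct from above:

    # each loss above equals minus the log of the probability of that label
    probs = nn.functional.softmax(logits, dim=-1)
    print(-torch.log(probs[0, 0]))  # matches loss_fct(logits, label0)
    print(-torch.log(probs[0, 1]))  # matches loss_fct(logits, label1)

    # the logit from the question: the larger value is at index 1
    example = torch.tensor([[-4.673831, 4.7095485]])
    print(torch.argmax(example, dim=-1))  # tensor([1]) -> contains adverse effect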