deep-learning, pytorch, bert-language-model, huggingface-transformers, transfer-learning

How to add a multiclass multilabel layer on top of pretrained BERT model?


I am trying to do multitask multiclass sentence classification using the pretrained BERT model from the huggingface transformers library. I have tried to use the BertForSequenceClassification model from there, but the issue I am having is that I am not able to extend it to multiple tasks. I will try to make this clearer with the following example.

Suppose we have four different tasks, and for each sentence each task has its own set of labels, as follows:

  1. A: ['a', 'b', 'c', 'd']
  2. B: ['e', 'f', 'g', 'h']
  3. C: ['i', 'j', 'k', 'l']
  4. D: ['m', 'n', 'o', 'p']

Now, if I give a sentence to this model, I want it to produce an output for each of the four tasks (A, B, C, D).
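
For reference, here is a minimal sketch of how these per-task label sets could be held in code (the task_labels name is purely illustrative, not part of any library):

    # Hypothetical per-task label sets; each task is an independent 4-way classification
    task_labels = {
        "A": ["a", "b", "c", "d"],
        "B": ["e", "f", "g", "h"],
        "C": ["i", "j", "k", "l"],
        "D": ["m", "n", "o", "p"],
    }
    # One sentence should then get one prediction per task, e.g. {"A": "b", "B": "e", "C": "k", "D": "p"}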

This is what I was doing earlier:

    model = BertForSequenceClassification.from_pretrained(
        "bert-base-uncased",           # Use the 12-layer BERT model, with an uncased vocab.
        num_labels = 4,                # The number of output labels -- 2 for binary classification.
                                       # You can increase this for multi-class tasks.
        output_attentions = False,     # Whether the model returns attention weights.
        output_hidden_states = False,  # Whether the model returns all hidden-states.
    )
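
Note that this builds a single classification head with one 4-way output, so a forward pass yields one set of logits per sentence rather than one per task. A quick check of this, assuming a recent transformers version where the model output exposes a .logits attribute (the sentence is made up):

    from transformers import BertTokenizer

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    inputs = tokenizer("An example sentence", return_tensors="pt")
    outputs = model(**inputs)
    print(outputs.logits.shape)   # torch.Size([1, 4]): one 4-way head, not one head per task (A, B, C, D)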

Then I tried to implement a custom BERT model like this:

class CustomBERTModel(nn.Module):
    def __init__(self):
          super(CustomBERTModel, self).__init__()
          self.bert = BertForSequenceClassification.from_pretrained("bert-base-uncased")
          ### New layers:
          self.linear1 = nn.Linear(768, 256)
          self.linear2 = nn.Linear(256, num_classes) ## num_classes is the number of classes in this example

    def forward(self, ids, mask):
          sequence_output, pooled_output = self.bert(
               ids, 
               attention_mask=mask)

          # sequence_output has the following shape: (batch_size, sequence_length, 768)
          linear1_output = self.linear1(sequence_output[:,0,:].view(-1,768)) 
          linear2_output = self.linear2(linear1_output)

          return linear2_output

I have gone through the answers to similar questions posted earlier, but none of them appeared to answer my question. I have tried to cover all the points I think are helpful for understanding my problem, and I will try to clarify further in case of any discrepancies in my explanation. Any answers related to this will be very much appreciated.


Solution

  • You should use BertModel and not BertForSequenceClassification, as BertForSequenceClassification adds a linear classification layer on top of the BERT model and uses CrossEntropyLoss, which is meant for single-label multiclass classification.

    Hence, first use BertModel instead of BertForSequenceClassification:

    class CustomBERTModel(nn.Module):
        def __init__(self):
              super(CustomBERTModel, self).__init__()
              self.bert = BertModel.from_pretrained("bert-base-uncased")
              ### New layers:
              self.linear1 = nn.Linear(768, 256)
              self.linear2 = nn.Linear(256, 4) ## as you have 4 classes in the output
              self.sig = nn.Sigmoid()   # module form of the sigmoid activation (nn.functional.sigmoid is a plain function)
    
        def forward(self, ids, mask):
              sequence_output, pooled_output = self.bert(
                   ids,
                   attention_mask=mask,
                   return_dict=False)  # ask for a plain tuple so it can be unpacked (newer transformers return a ModelOutput by default)
    
              # sequence_output has the following shape: (batch_size, sequence_length, 768)
              linear1_output = self.linear1(sequence_output[:,0,:].view(-1,768)) 
              linear2_output = self.linear2(linear1_output)
              linear2_output = self.sig(linear2_output)
    
              return linear2_output
    
    

    Next, multilabel classification uses a sigmoid activation instead of softmax (the sigmoid layer has been added in the code above).

    Further, for multilabel classification, you need to use BCELoss instead of CrossEntropyLoss.
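
    Putting it together, here is a minimal single-batch training sketch that assumes the CustomBERTModel defined above; the sentence and the multi-hot target are made up for illustration:

        import torch
        import torch.nn as nn
        from transformers import BertTokenizer

        tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        model = CustomBERTModel()

        criterion = nn.BCELoss()   # expects probabilities, i.e. outputs already passed through the sigmoid
        optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)

        enc = tokenizer(["An example sentence"], return_tensors="pt")
        probs = model(enc["input_ids"], enc["attention_mask"])   # shape (1, 4), values in [0, 1]
        targets = torch.tensor([[1., 0., 1., 0.]])               # multi-hot labels: classes 0 and 2 are active

        loss = criterion(probs, targets)
        loss.backward()
        optimizer.step()

    As a design note, torch.nn.BCEWithLogitsLoss combines the sigmoid and BCELoss in a single, more numerically stable step; if you use it, drop the sigmoid at the end of forward and pass the raw linear2 outputs to the loss.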