BertModel or BertForPreTraining


I want to use BERT only for embeddings and feed the BERT output into a classification network that I will build from scratch.

I am not sure yet whether I want to fine-tune the model.

I think the relevant classes are BertModel or BertForPreTraining.

The BertForPreTraining head contains two "actions": self.predictions is the MLM (Masked Language Modeling) head, which predicts the original tokens at masked positions and is what gives BERT the power to fix grammar errors; self.seq_relationship is the NSP (Next Sentence Prediction) head, usually referred to as the classification head.

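import torch.nn as nn
from transformers.models.bert.modeling_bert import BertLMPredictionHead

class BertPreTrainingHeads(nn.Module):
    def __init__(self, config):
        super().__init__()
        # MLM head: scores every vocabulary token at each position
        self.predictions = BertLMPredictionHead(config)
        # NSP head: two logits per sequence (is-next / not-next)
        self.seq_relationship = nn.Linear(config.hidden_size, 2)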

I think NSP isn't relevant for my task, so I can "override" it. What does MLM do, and is it relevant for my goal, or should I use BertModel?
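To see concretely what each head produces, the sketch below runs BertForPreTraining on random token ids. It assumes the `transformers` and `torch` packages and uses a tiny, randomly initialized BertConfig (the sizes are arbitrary illustration values, not bert-base) so it runs without downloading weights. The MLM head returns a score for every vocabulary entry at every position, which is how masked tokens are predicted; the NSP head returns two scores per sequence. The encoder itself lives in the `bert` attribute, which is a plain BertModel.

```python
import torch
from transformers import BertConfig, BertForPreTraining, BertModel

# Tiny, randomly initialized config (illustrative sizes) so the sketch runs offline;
# with real weights you would use BertForPreTraining.from_pretrained("bert-base-uncased").
config = BertConfig(
    vocab_size=100,
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=2,
    intermediate_size=64,
)
model = BertForPreTraining(config).eval()

# BertForPreTraining = BertModel encoder (model.bert) + the two pre-training heads (model.cls).
assert isinstance(model.bert, BertModel)

input_ids = torch.randint(0, config.vocab_size, (2, 8))  # batch of 2 sequences, 8 tokens each
with torch.no_grad():
    out = model(input_ids=input_ids)

# MLM head: one score per vocabulary entry at every position
print(out.prediction_logits.shape)        # torch.Size([2, 8, 100])
# NSP head: two scores per sequence (is-next / not-next)
print(out.seq_relationship_logits.shape)  # torch.Size([2, 2])
```

Neither output is a sentence embedding; for that purpose, the encoder's hidden states are what matter.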


Solution

  • You should be using BertModel instead of BertForPreTraining.

    BertForPreTraining is used to pre-train BERT on the Masked Language Model (MLM) and Next Sentence Prediction (NSP) tasks. Those heads serve the pre-training objectives and are not meant for your downstream classification.

    BertModel simply returns the hidden states of the BERT encoder, with no task-specific head on top. You can then fine-tune BERT together with the classifier you build on top of it. For classification, if it's just a single layer on top of the BERT model, you can use BertForSequenceClassification directly.

    In any case, if you just want to take the output of the BERT model and train your classifier (without fine-tuning BERT), you can freeze the BERT weights using:

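    from transformers import BertForSequenceClassification

    model = BertForSequenceClassification.from_pretrained('bert-base-uncased')
    
    # model.bert is the underlying BertModel; freezing it leaves only the classifier head trainable
    for param in model.bert.parameters():
        param.requires_grad = False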
    

    The above code is borrowed from here
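Following the answer's advice, here is a minimal sketch of the setup from the question: BertModel as a frozen feature extractor with a classifier built from scratch on top. The class name `BertFeatureClassifier`, its two-layer head, and the use of the [CLS] token's final hidden state as the sentence embedding are illustrative assumptions, not the only option (mean pooling over tokens or `pooler_output` are common alternatives). A tiny random BertConfig stands in for pretrained weights so the sketch runs offline.

```python
import torch
import torch.nn as nn
from transformers import BertConfig, BertModel


class BertFeatureClassifier(nn.Module):
    """Hypothetical from-scratch classifier on top of BertModel outputs."""

    def __init__(self, bert: BertModel, num_labels: int, freeze_bert: bool = True):
        super().__init__()
        self.bert = bert
        if freeze_bert:
            # Freeze the encoder so only the classifier below receives gradient updates.
            for param in self.bert.parameters():
                param.requires_grad = False
        hidden = bert.config.hidden_size
        self.classifier = nn.Sequential(
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden, num_labels),
        )

    def forward(self, input_ids, attention_mask=None):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # Final hidden state of the first ([CLS]) token used as the sentence embedding.
        cls_embedding = outputs.last_hidden_state[:, 0]
        return self.classifier(cls_embedding)


# Tiny random config so the sketch runs offline; in practice load pretrained
# weights with BertModel.from_pretrained("bert-base-uncased").
config = BertConfig(vocab_size=100, hidden_size=32, num_hidden_layers=2,
                    num_attention_heads=2, intermediate_size=64)
model = BertFeatureClassifier(BertModel(config), num_labels=3)

# Only the classifier.* parameters should still require gradients.
trainable = [name for name, p in model.named_parameters() if p.requires_grad]
print(trainable)

input_ids = torch.randint(0, config.vocab_size, (4, 10))  # batch of 4, 10 tokens each
logits = model(input_ids)
print(logits.shape)  # torch.Size([4, 3])

# The optimizer is given only the trainable (classifier) parameters.
optimizer = torch.optim.AdamW(
    (p for p in model.parameters() if p.requires_grad), lr=1e-3
)
```

Passing `freeze_bert=False` fine-tunes BERT together with the head instead. A frozen encoder still applies dropout while in training mode, so calling `model.bert.eval()` after each `model.train()` keeps the extracted features deterministic.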