python, nlp, pytorch, huggingface-transformers, bert-language-model

Save a BERT model with a custom forward function and heads on Hugging Face


I have created my own BertClassifier model, starting from a pretrained model and then adding my own classification head composed of different layers. After fine-tuning, I want to save the model using model.save_pretrained(), but when I load it back with from_pretrained and print it, I don't see my classifier head. The code is below. How can I save the whole structure of my model and make it fully accessible with AutoModel.from_pretrained('folder_path')? Thanks!

import torch.nn as nn
from transformers import AutoConfig, AutoModel, BertModel, PreTrainedModel

class BertClassifier(PreTrainedModel):
    """BERT model for classification tasks."""
    config_class = AutoConfig

    def __init__(self, config, freeze_bert=True):  # tuning only the head
        """
         @param    config: the model configuration
         @param    freeze_bert (bool): Set `False` to fine-tune the BERT model
        """
        super().__init__(config)

        # Specify hidden size of BERT, hidden size of our classifier, and number of labels
        self.D_in = 1024  # hidden size of BERT
        self.H = 512
        self.D_out = 2

        # Instantiate the BERT model and optionally freeze its parameters
        self.bert = BertModel(config)
        if freeze_bert:
            for param in self.bert.parameters():
                param.requires_grad = False

        # Instantiate the classifier head: a small feed-forward network
        self.classifier = nn.Sequential(
            nn.Linear(self.D_in, self.H),
            nn.Tanh(),
            nn.Linear(self.H, self.D_out),
            nn.Tanh()
        )
 


    def forward(self, input_ids, attention_mask):
        # Feed input to BERT
        outputs = self.bert(input_ids=input_ids,
                            attention_mask=attention_mask)

        # Extract the last hidden state of the `[CLS]` token for the classification task
        last_hidden_state_cls = outputs[0][:, 0, :]

        # Feed the [CLS] representation to the classifier to compute logits
        logits = self.classifier(last_hidden_state_cls)

        return logits

configuration = AutoConfig.from_pretrained('Rostlab/prot_bert_bfd')
model = BertClassifier(config=configuration, freeze_bert=False)

Saving the model after fine-tuning

model.save_pretrained('path')

Loading the fine-tuned model

model = AutoModel.from_pretrained('path') 

Printing the model after loading shows the following as the last layers; my two linear classifier layers are missing:

 (output): BertOutput(
          (dense): Linear(in_features=4096, out_features=1024, bias=True)
          (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
          (adapters): ModuleDict()
          (adapter_fusion_layer): ModuleDict()
        )
      )
    )
  )
  (pooler): BertPooler(
    (dense): Linear(in_features=1024, out_features=1024, bias=True)
    (activation): Tanh()
  )
  (prefix_tuning): PrefixTuningPool(
    (prefix_tunings): ModuleDict()
  )
)

Solution

  • Maybe something is wrong with the config_class attribute inside your BertClassifier class. According to the documentation you need to create an additional config class which inherits from PretrainedConfig and initialises the model_type attribute with the name of your custom model.

    The BertClassifier's config_class attribute has to refer to your custom config class. Afterwards you can register your config and model with the following calls:

    AutoConfig.register('CustomModelName', CustomModelConfigClass)
    AutoModel.register(CustomModelConfigClass, CustomModelClass)
    

    And load your fine-tuned model with AutoModel.from_pretrained(), passing the directory you saved it to.

    An incomplete example based on your code could look like this:

    class BertClassifierConfig(PretrainedConfig):
        model_type = "BertClassifier"
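        # (hypothetical extension, not in the original answer) any
        # hyperparameters the classifier needs, e.g. its layer sizes, could
        # be accepted in __init__ and forwarded to super().__init__(**kwargs)
        # so they get serialized into config.json alongside model_type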
    
    
    class BertClassifier(PreTrainedModel):
        config_class = BertClassifierConfig
        # ...
    
    
    configuration = BertClassifierConfig()
    bert_classifier = BertClassifier(configuration)
    
    # do your finetuning and save your custom model
    bert_classifier.save_pretrained("CustomModels/BertClassifier")
    
    # register your config and your model
    AutoConfig.register("BertClassifier", BertClassifierConfig)
    AutoModel.register(BertClassifierConfig, BertClassifier)
    
    # load your model with AutoModel
    bert_classifier_model = AutoModel.from_pretrained("CustomModels/BertClassifier")
    

    Printing the model output should be similar to this:

        (pooler): BertPooler(
          (dense): Linear(in_features=768, out_features=768, bias=True)
          (activation): Tanh()
        )
      )
      (classifier): Sequential(
        (0): Linear(in_features=1024, out_features=512, bias=True)
        (1): Tanh()
        (2): Linear(in_features=512, out_features=2, bias=True)
        (3): Tanh()
        (4): Linear(in_features=2, out_features=512, bias=True)
        (5): Tanh()
      )
    
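    As a quick sanity check (a minimal sketch, not part of the original answer; it assumes the Rostlab/prot_bert_bfd tokenizer matches your model and that self.bert is instantiated in __init__), you could run a forward pass through the reloaded model to confirm the classifier head is actually used:

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("Rostlab/prot_bert_bfd")
    # ProtBert expects space-separated amino acids
    inputs = tokenizer("M K T A Y I A K Q R", return_tensors="pt")
    logits = bert_classifier_model(input_ids=inputs["input_ids"],
                                   attention_mask=inputs["attention_mask"])
    print(logits.shape)  # expected: torch.Size([1, 2])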

    Hope this helps.

    https://huggingface.co/docs/transformers/custom_models#registering-a-model-with-custom-code-to-the-auto-classes