I have created my own BertClassifier model, starting from a pretrained BERT, and added my own classification head composed of several layers. After fine-tuning, I want to save the model using model.save_pretrained(), but when I load it back with from_pretrained and print it, I don't see my classifier head.
The code is below. How can I save the whole structure of my model and make it fully accessible with AutoModel.from_pretrained('folder_path')? Thanks!
import torch.nn as nn
from transformers import AutoConfig, AutoModel, PreTrainedModel

class BertClassifier(PreTrainedModel):
    """Bert Model for Classification Tasks."""
    config_class = AutoConfig

    def __init__(self, config, freeze_bert=True):  # tuning only the head
        """
        @param config: an AutoConfig object for the pretrained BERT model
        @param freeze_bert (bool): Set `False` to fine-tune the BERT model
        """
        super().__init__(config)
        # Instantiate the BERT backbone from the config
        self.bert = AutoModel.from_config(config)
        # Optionally freeze the backbone so only the head is trained
        if freeze_bert:
            for param in self.bert.parameters():
                param.requires_grad = False
        # Hidden size of BERT, hidden size of our classifier, and number of labels
        self.D_in = 1024  # hidden size of Bert
        self.H = 512
        self.D_out = 2
        # Instantiate the classifier head as a small feed-forward network
        self.classifier = nn.Sequential(
            nn.Linear(self.D_in, self.H),
            nn.Tanh(),
            nn.Linear(self.H, self.D_out),
            nn.Tanh()
        )

    def forward(self, input_ids, attention_mask):
        # Feed input to BERT
        outputs = self.bert(input_ids=input_ids,
                            attention_mask=attention_mask)
        # Extract the last hidden state of the `[CLS]` token for classification
        last_hidden_state_cls = outputs[0][:, 0, :]
        # Feed the `[CLS]` representation to the classifier to compute logits
        logits = self.classifier(last_hidden_state_cls)
        return logits
configuration = AutoConfig.from_pretrained('Rostlab/prot_bert_bfd')
model = BertClassifier(config=configuration, freeze_bert=False)
Saving the model after fine-tuning:
model.save_pretrained('path')
Loading the fine-tuned model:
model = AutoModel.from_pretrained('path')
Printing the model after loading shows the following as the last layers; my two linear layers are missing:
(output): BertOutput(
(dense): Linear(in_features=4096, out_features=1024, bias=True)
(LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.0, inplace=False)
(adapters): ModuleDict()
(adapter_fusion_layer): ModuleDict()
)
)
)
)
(pooler): BertPooler(
(dense): Linear(in_features=1024, out_features=1024, bias=True)
(activation): Tanh()
)
(prefix_tuning): PrefixTuningPool(
(prefix_tunings): ModuleDict()
)
)
Maybe something is wrong with the config_class attribute inside your BertClassifier class. According to the documentation, you need to create an additional config class which inherits from PretrainedConfig and initialises the model_type attribute with the name of your custom model. The BertClassifier's config_class has to be consistent with your custom config class type.
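Note that your classifier weights are not lost: save_pretrained() writes the full state_dict of the module, but AutoModel resolves the saved config to a plain BertModel and never rebuilds the head. A quick check (just a sketch, assuming the default pytorch_model.bin serialization; newer transformers versions write model.safetensors instead):
import torch
# save_pretrained() stored the full state_dict, including the head
state_dict = torch.load('path/pytorch_model.bin', map_location='cpu')
print([k for k in state_dict if k.startswith('classifier')])
# expected: ['classifier.0.weight', 'classifier.0.bias',
#            'classifier.2.weight', 'classifier.2.bias']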
Once your config class is in place, you can register your config and model with the following calls:
AutoConfig.register('CustomModelName', CustomModelConfigClass)
AutoModel.register(CustomModelConfigClass, CustomModelClass)
Then you can load your fine-tuned model with AutoModel.from_pretrained, pointing it at the folder you saved it to.
An incomplete example based on your code could look like this:
from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel

class BertClassifierConfig(PretrainedConfig):
    model_type = "BertClassifier"

class BertClassifier(PreTrainedModel):
    config_class = BertClassifierConfig
    # ...

configuration = BertClassifierConfig()
bert_classifier = BertClassifier(configuration)

# do your fine-tuning and save your custom model
bert_classifier.save_pretrained("CustomModels/BertClassifier")

# register your config and your model
AutoConfig.register("BertClassifier", BertClassifierConfig)
AutoModel.register(BertClassifierConfig, BertClassifier)

# load your model with AutoModel
bert_classifier_model = AutoModel.from_pretrained("CustomModels/BertClassifier")
Printing the loaded model should give output similar to this:
(pooler): BertPooler(
  (dense): Linear(in_features=1024, out_features=1024, bias=True)
  (activation): Tanh()
)
)
(classifier): Sequential(
  (0): Linear(in_features=1024, out_features=512, bias=True)
  (1): Tanh()
  (2): Linear(in_features=512, out_features=2, bias=True)
  (3): Tanh()
)
Hope this helps.