Tags: python, pytorch, huggingface-transformers, bert-language-model, allennlp

How do I load a fine-tuned AllenNLP BERT-SRL model using BertPreTrainedModel.from_pretrained()?


I have fine-tuned a BERT model for semantic role labeling, using AllenNLP. This produces a model directory (serialization directory, if I recall?) that contains the following:

best.th
config.json
meta.json
metrics_epoch_0.json
metrics_epoch_10.json
metrics_epoch_11.json
metrics_epoch_12.json
metrics_epoch_13.json
metrics_epoch_14.json
metrics_epoch_1.json
metrics_epoch_2.json
metrics_epoch_3.json
metrics_epoch_4.json
metrics_epoch_5.json
metrics_epoch_6.json
metrics_epoch_7.json
metrics_epoch_8.json
metrics_epoch_9.json
metrics.json
model_state_e14_b0.th
model_state_e15_b0.th
model.tar.gz
out.log
training_state_e14_b0.th
training_state_e15_b0.th
vocabulary

Where vocabulary is a folder with labels.txt and non_padded_namespaces.txt.

I'd now like to use this fine-tuned BERT model as the initialization when learning a related task, event extraction, using this library: https://github.com/wilsonlau-uw/BERT-EE (i.e., I want to exploit some transfer learning). The config.ini file has a fine_tuned_path line, where I can specify an already fine-tuned model to use. I provided the path to the AllenNLP serialization directory and got the following error:

2022-04-05 13:07:28,112 -  INFO - setting seed 23
2022-04-05 13:07:28,113 -  INFO - loading fine tuned model in /data/projects/SRL/ser_pure_clinical_bert-large_thyme_and_ontonotes/
Traceback (most recent call last):
  File "main.py", line 65, in <module>
    model = BERT_EE()
  File "/data/projects/SRL/BERT-EE/model.py", line 88, in __init__
    self.__build(self.use_fine_tuned)
  File "/data/projects/SRL/BERT-EE/model.py", line 118, in __build
    self.__get_pretrained(self.fine_tuned_path)
  File "/data/projects/SRL/BERT-EE/model.py", line 110, in __get_pretrained
    self.__model = BERT_EE_model.from_pretrained(path)
  File "/home/richier/anaconda3/envs/allennlp/lib/python3.7/site-packages/transformers/modeling_utils.py", line 1109, in from_pretrained
    f"Error no file named {[WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME + '.index', FLAX_WEIGHTS_NAME]} found in "
OSError: Error no file named ['pytorch_model.bin', 'tf_model.h5', 'model.ckpt.index', 'flax_model.msgpack'] found in directory /data/projects/SRL/ser_pure_clinical_bert-large_thyme_and_ontonotes/ or `from_tf` and `from_flax` set to False.

Of course, the serialization directory doesn't have any of those files, hence the error. I tried unzipping model.tar.gz but it only has:

config.json
weights.th
vocabulary/
vocabulary/.lock
vocabulary/labels.txt
vocabulary/non_padded_namespaces.txt
meta.json

Digging into the codebase of the GitHub repo I linked above, I can see that BERT_EE_model inherits from BertPreTrainedModel from the transformers library, so the trick would seem to be getting the AllenNLP model into a format that BertPreTrainedModel.from_pretrained() can load...?
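To make the mismatch concrete, here is a minimal sketch that checks a directory for the weight file names listed in the OSError above. The helper name find_weight_files is hypothetical, and the list of names is copied from the error message for this transformers version; other releases may look for additional names.

```python
from pathlib import Path

# File names copied from the OSError in the traceback. Assumption: other
# transformers releases may accept additional names not listed here.
EXPECTED_WEIGHT_FILES = (
    "pytorch_model.bin",
    "tf_model.h5",
    "model.ckpt.index",
    "flax_model.msgpack",
)


def find_weight_files(directory):
    """Return the expected weight file names that exist in `directory`."""
    root = Path(directory)
    return [name for name in EXPECTED_WEIGHT_FILES if (root / name).is_file()]
```

Given the listings above, calling find_weight_files on either the serialization directory or the unpacked model.tar.gz should return an empty list, which is consistent with the error: AllenNLP stores its weights as best.th / weights.th rather than under any of these names.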

Any help would be greatly appreciated!


Solution

  • I believe I have figured this out. Basically, I had to re-load my model archive, access the underlying model and tokenizer, and then save those:

    from allennlp.models.archival import load_archive
    # Not used directly: importing these registers the SRL classes so load_archive can resolve them.
    from allennlp_models.structured_prediction import SemanticRoleLabeler, srl, srl_bert
    
    archive_path = 'ser_pure_clinical_bert-large_thyme_and_ontonotes/model.tar.gz'
    export_dir = 'ser_pure_clinical_bert-large_thyme_and_ontonotes_save_pretrained/'
    
    # Re-load the AllenNLP archive (weights, config and vocabulary).
    archive = load_archive(archive_path)
    
    # The wrapped Hugging Face model; its type is transformers.models.bert.modeling_bert.BertModel.
    bert_model = archive.model.bert_model
    bert_model.save_pretrained(export_dir)
    
    # Save the tokenizer files into the same directory as the model.
    bert_tokenizer = archive.dataset_reader.bert_tokenizer
    bert_tokenizer.save_pretrained(export_dir)
    

    (This last part is probably less interesting to most folks, but in the config.ini I mentioned, the directory 'ser_pure_clinical_bert-large_thyme_and_ontonotes_save_pretrained' had to be given as the value of pretrained_model_name_or_path, not fine_tuned_path.)
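As a rough illustration of that config.ini change, the sketch below parses a hypothetical fragment with Python's configparser. Only the two key names come from this post; the section name "model" and the helper read_model_paths are assumptions, and fine_tuned_path is left empty purely for illustration (how BERT-EE treats an empty value is not confirmed here).

```python
import configparser

# Hypothetical fragment: the section name "model" is an assumption.
# Only the two key names are taken from the post.
CONFIG_TEXT = """
[model]
pretrained_model_name_or_path = ser_pure_clinical_bert-large_thyme_and_ontonotes_save_pretrained
fine_tuned_path =
"""


def read_model_paths(text):
    """Parse the fragment and return (pretrained path, fine-tuned path)."""
    parser = configparser.ConfigParser()
    parser.read_string(text)
    section = parser["model"]
    return section["pretrained_model_name_or_path"], section["fine_tuned_path"]
```

The point is only that the exported directory goes under pretrained_model_name_or_path; check the actual config.ini shipped with BERT-EE for the real section layout.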