
Load a model as DPRQuestionEncoder in HuggingFace

I would like to load BERT's weights (or those of any other transformer) into a DPRQuestionEncoder architecture, so that I can use the HuggingFace save_pretrained method and plug the saved model into the RAG architecture for end-to-end fine-tuning.

from transformers import DPRQuestionEncoder
model = DPRQuestionEncoder.from_pretrained('bert-base-uncased')

But I got the following error:

You are using a model of type bert to instantiate a model of type dpr. This is not supported for all configurations of models and can yield errors.

NotImplementedError                       Traceback (most recent call last)
<ipython-input-27-1f1b990b906b> in <module>
----> 1 model = DPRQuestionEncoder.from_pretrained(model_name)
      2 #

/opt/conda/lib/python3.8/site-packages/transformers/ in from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
   1211                     )
-> 1213             model, missing_keys, unexpected_keys, error_msgs = cls._load_state_dict_into_model(
   1214                 model, state_dict, pretrained_model_name_or_path, _fast_init=_fast_init
   1215             )

/opt/conda/lib/python3.8/site-packages/transformers/ in _load_state_dict_into_model(cls, model, state_dict, pretrained_model_name_or_path, _fast_init)
   1286             )
   1287             for module in unintialized_modules:
-> 1288                 model._init_weights(module)
   1290         # copy state_dict so _load_from_state_dict can modify it

/opt/conda/lib/python3.8/site-packages/transformers/ in _init_weights(self, module)
    515         Initialize the weights. This method should be overridden by derived class.
    516         """
--> 517         raise NotImplementedError(f"Make sure `_init_weigths` is implemented for {self.__class__}")
    519     def tie_weights(self):

NotImplementedError: Make sure `_init_weigths` is implemented for <class 'transformers.models.dpr.modeling_dpr.DPRQuestionEncoder'>

I am using the latest version of Transformers.


  • As already mentioned in the comments, DPRQuestionEncoder does not currently provide any functionality for loading other models. I still recommend creating your own class that inherits from DPRQuestionEncoder, loads your custom model, and adjusts its methods accordingly.

    You also asked in the comments whether there is another way. There is, provided that the parameters of your model and those of the model held by your DPRQuestionEncoder object are exactly the same. Please have a look at the commented example below:

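    import torch
    from transformers import BertModel, DPRQuestionEncoder

    # here I am just loading a bert model that represents your model
    ordinary_bert = BertModel.from_pretrained("bert-base-uncased")
    # save it to disk so that its raw state dict can be loaded below
    ordinary_bert.save_pretrained("this_is_a_bert")

    # now we load the state dict (i.e. weights and biases) of your model
    # note: recent transformers versions may write model.safetensors instead of
    # pytorch_model.bin; in that case use ordinary_bert.state_dict() directly
    ordinary_bert_state_dict = torch.load("this_is_a_bert/pytorch_model.bin")

    # here we create a DPRQuestionEncoder object
    # the facebook/dpr-question_encoder-single-nq-base has the same parameters as bert-base-uncased
    # You can compare the respective configs or model.parameters to be sure
    model = DPRQuestionEncoder.from_pretrained("facebook/dpr-question_encoder-single-nq-base")

    # we need to create the same keys (i.e. layer names) as your target model facebook/dpr-question_encoder-single-nq-base
    ordinary_bert_state_dict = {f"question_encoder.bert_model.{k}": v for k, v in ordinary_bert_state_dict.items()}

    # now we can load the bert-base-uncased weights into your DPRQuestionEncoder object
    # strict=False tolerates keys present on only one side (depending on the transformers
    # version, e.g. BERT's pooler); check the reported keys to make sure nothing important is left out
    result = model.load_state_dict(ordinary_bert_state_dict, strict=False)
    print("missing:", result.missing_keys, "unexpected:", result.unexpected_keys)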

    This works from a technical perspective, but I cannot tell you how it will perform on your task.
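The check the answer recommends ("compare the respective configs or model.parameters") can be scripted before copying any weights. Below is a framework-free sketch, assuming you first turn each model into a plain dict of parameter names to shapes (for example `{k: tuple(v.shape) for k, v in model.state_dict().items()}`) and each config into a dict. The helper names (`config_differences`, `remap_keys`, `compare_state_dicts`), the list of architecture fields, and the toy data are illustrative assumptions, not part of the transformers API; only the `question_encoder.bert_model.` prefix comes from the answer above.

```python
# Hypothetical helpers: plain dicts stand in for real configs and state dicts.
PREFIX = "question_encoder.bert_model."  # key prefix used in the answer above

# BERT-style config fields that determine parameter shapes (assumed list).
ARCH_FIELDS = (
    "vocab_size", "hidden_size", "num_hidden_layers", "num_attention_heads",
    "intermediate_size", "max_position_embeddings", "type_vocab_size",
)


def config_differences(src_cfg: dict, dst_cfg: dict) -> dict:
    """Return {field: (src_value, dst_value)} for architecture fields that differ."""
    return {
        field: (src_cfg.get(field), dst_cfg.get(field))
        for field in ARCH_FIELDS
        if src_cfg.get(field) != dst_cfg.get(field)
    }


def remap_keys(src_shapes: dict, prefix: str) -> dict:
    """Prepend the target model's module path to every source key."""
    return {f"{prefix}{key}": shape for key, shape in src_shapes.items()}


def compare_state_dicts(src_shapes: dict, dst_shapes: dict) -> dict:
    """List keys missing from src, keys unexpected in src, and shape mismatches."""
    missing = sorted(k for k in dst_shapes if k not in src_shapes)
    unexpected = sorted(k for k in src_shapes if k not in dst_shapes)
    mismatched = sorted(
        k for k in src_shapes.keys() & dst_shapes.keys()
        if src_shapes[k] != dst_shapes[k]
    )
    return {"missing": missing, "unexpected": unexpected, "mismatched": mismatched}


if __name__ == "__main__":
    # Toy, hand-written excerpts (illustrative only, not complete state dicts).
    bert_shapes = {
        "embeddings.word_embeddings.weight": (30522, 768),
        "encoder.layer.0.attention.self.query.weight": (768, 768),
        "pooler.dense.weight": (768, 768),
    }
    dpr_shapes = {
        PREFIX + "embeddings.word_embeddings.weight": (30522, 768),
        PREFIX + "encoder.layer.0.attention.self.query.weight": (768, 768),
    }
    print(config_differences({"hidden_size": 768}, {"hidden_size": 768}))
    # The toy pooler key has no counterpart in dpr_shapes, so it is reported as unexpected.
    print(compare_state_dicts(remap_keys(bert_shapes, PREFIX), dpr_shapes))
```

An empty `mismatched` list, with only explainable entries under `missing` or `unexpected` (such as a pooler that one side lacks), suggests the remapped weights line up; any shape mismatch means the architectures differ and the weights should not be copied this way.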