Tags: keras, transformer-model, bert-language-model

TFBertMainLayer gets lower accuracy than TFBertModel


I had a problem saving the weights of a TFBertModel wrapped in Keras. The problem is described here in a GitHub issue and here on Stack Overflow. The solution proposed in both cases is to use

 config = BertConfig.from_pretrained(transformer_model_name)
 bert = TFBertMainLayer(config=config, trainable=False)

instead of

 bert = TFBertModel.from_pretrained(transformer_model_name, trainable=False)

The problem is that when I change my model to the former code, the accuracy decreases by 10 percent, even though the parameter counts in both cases are the same. I wonder what the reason is and how it can be prevented?


Solution

  • It seems the performance regression in the code snippet that instantiates MainLayer directly occurs because the pre-trained weights are never loaded. You can load the weights by either:

    1. Calling TFBertModel.from_pretrained and grabbing the MainLayer from the loaded TFBertModel (a short sketch follows this list)
    2. Creating the MainLayer directly, then loading the weights in a similar way to from_pretrained
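
    For the TFBertModel in the question, option 1 can be as short as the sketch below (bert is the attribute name TFBertModel gives its MainLayer field; the excerpt of TFDistilBertModel further down shows the analogous field):

        from transformers import TFBertModel

        bert_model = TFBertModel.from_pretrained(transformer_model_name, trainable=False)
        bert = bert_model.bert  # the TFBertMainLayer, pre-trained weights already loaded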

    Why This Happens

    When you call TFBertModel.from_pretrained, it uses TFPreTrainedModel.from_pretrained (via inheritance), which handles a few things, including downloading, caching, and loading the model weights.

    class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
        ...
        @classmethod
        def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
            ...
            # Load model
            if pretrained_model_name_or_path is not None:
                if os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)):
                    # Load from a TF 2.0 checkpoint
                    archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)
                ...
                resolved_archive_file = cached_path(
                    archive_file,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
                    resume_download=resume_download,
                    local_files_only=local_files_only,
                )
                ...
                model.load_weights(resolved_archive_file, by_name=True)
    

    (If you read the actual code, note that large parts have been elided with ... above.)

    However, when you instantiate TFBertMainLayer directly, it doesn't do any of this setup work.

    @keras_serializable
    class TFBertMainLayer(tf.keras.layers.Layer):
        config_class = BertConfig
    
        def __init__(self, config, **kwargs):
            super().__init__(**kwargs)
            self.num_hidden_layers = config.num_hidden_layers
            self.initializer_range = config.initializer_range
            self.output_attentions = config.output_attentions
            self.output_hidden_states = config.output_hidden_states
            self.return_dict = config.use_return_dict
            self.embeddings = TFBertEmbeddings(config, name="embeddings")
            self.encoder = TFBertEncoder(config, name="encoder")
            self.pooler = TFBertPooler(config, name="pooler")
       
        ... rest of the class
    

    Essentially, you need to make sure these weights are being loaded.
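
    One way to verify this is to compare the weights of a directly-instantiated main layer with those of one grabbed from a loaded model. A quick check might look like the sketch below (the TFBertMainLayer import path varies across transformers versions; this one is from 4.x):

        import numpy as np
        import tensorflow as tf
        from transformers import BertConfig, TFBertModel
        from transformers.models.bert.modeling_tf_bert import TFBertMainLayer

        model_name = "bert-base-uncased"

        # Directly-instantiated main layer: weights are randomly initialized.
        raw_layer = TFBertMainLayer(config=BertConfig.from_pretrained(model_name))
        raw_layer(tf.constant([[101, 102]]))  # dummy input to build the variables

        # Main layer grabbed from a loaded model: weights are pre-trained.
        loaded_layer = TFBertModel.from_pretrained(model_name).bert

        # Same shapes, different values: almost no tensors should match.
        matches = [np.allclose(a, b) for a, b in
                   zip(raw_layer.get_weights(), loaded_layer.get_weights())]
        print(f"{sum(matches)} of {len(matches)} weight tensors match")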

    Solutions

    (1) Using TFAutoModel.from_pretrained

    You can rely on transformers.TFAutoModel.from_pretrained to load the model, then just grab the MainLayer field from the specific subclass of TFPreTrainedModel. For example, if you wanted to access a distilbert main layer, it would look like:

        import transformers
        from transformers import TFDistilBertModel

        model = transformers.TFAutoModel.from_pretrained("distilbert-base-uncased")
        assert isinstance(model, TFDistilBertModel)
        main_layer = model.distilbert
    

    You can see in modeling_tf_distilbert.html that the MainLayer is a field of the model. This is less code and less duplication, but it has a few disadvantages. It's harder to change the pre-trained model you're using, because you now depend on the field name: if you change the model type, you'll have to change the field name too (for example, in TFAlbertModel the MainLayer field is called albert). In addition, this doesn't seem to be the intended way to use huggingface, so it could change under your nose, and your code could break with huggingface updates.

    class TFDistilBertModel(TFDistilBertPreTrainedModel):
        def __init__(self, config, *inputs, **kwargs):
            super().__init__(config, *inputs, **kwargs)
            self.distilbert = TFDistilBertMainLayer(config, name="distilbert")  # Embeddings
    
        @add_start_docstrings_to_callable(DISTILBERT_INPUTS_DOCSTRING)
        @add_code_sample_docstrings(
            tokenizer_class=_TOKENIZER_FOR_DOC,
            checkpoint="distilbert-base-uncased",
            output_type=TFBaseModelOutput,
            config_class=_CONFIG_FOR_DOC,
        )
        def call(self, inputs, **kwargs):
            outputs = self.distilbert(inputs, **kwargs)
            return outputs
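
    Once you have the main layer, you can drop it into a Keras model the same way the question wraps TFBertModel. A minimal sketch (the sequence length, classification head, and compile settings are illustrative, not part of the original answer):

        import tensorflow as tf
        from transformers import TFAutoModel

        pretrained = TFAutoModel.from_pretrained("distilbert-base-uncased")
        main_layer = pretrained.distilbert  # field name depends on the model type
        main_layer.trainable = False

        input_ids = tf.keras.Input(shape=(128,), dtype=tf.int32, name="input_ids")
        attention_mask = tf.keras.Input(shape=(128,), dtype=tf.int32, name="attention_mask")

        sequence_output = main_layer(input_ids, attention_mask=attention_mask)[0]
        cls_token = sequence_output[:, 0, :]  # representation of the first token
        output = tf.keras.layers.Dense(1, activation="sigmoid")(cls_token)

        model = tf.keras.Model(inputs=[input_ids, attention_mask], outputs=output)
        model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])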
    

    (2) Re-implementing the weight loading logic from from_pretrained

    You can do this by essentially copy/pasting the parts of from_pretrained that are relevant to loading weights. This also has some serious disadvantages: you'll be duplicating logic that can fall out of sync with the huggingface libraries, though you could likely write it in a way that is more flexible and robust to underlying model name changes.
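
    A sketch of what this could look like, mirroring the load_weights(..., by_name=True) call shown above (the use of hf_hub_download, and the assumption that the checkpoint's top-level layer is named "bert", are mine, not taken from from_pretrained itself):

        import tensorflow as tf
        from huggingface_hub import hf_hub_download
        from transformers import BertConfig
        from transformers.models.bert.modeling_tf_bert import TFBertMainLayer

        model_name = "bert-base-uncased"
        config = BertConfig.from_pretrained(model_name)

        # Name the layer "bert" so its weights line up with the layer name
        # TFBertModel uses when it writes the TF2 checkpoint.
        input_ids = tf.keras.Input(shape=(None,), dtype=tf.int32, name="input_ids")
        bert = TFBertMainLayer(config=config, trainable=False, name="bert")
        model = tf.keras.Model(inputs=input_ids, outputs=bert(input_ids)[0])

        # Mirror from_pretrained: fetch the TF2 checkpoint, then load by name.
        archive_file = hf_hub_download(repo_id=model_name, filename="tf_model.h5")
        model.load_weights(archive_file, by_name=True)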

    Conclusion

    Ideally this is something that will get fixed internally by the huggingface team, either by providing a standard function to create a MainLayer, by wrapping the weight-loading logic into its own function that can be called, or by supporting serialization on the model class.