
Albert_base: weights from ckpt not loaded properly when loading with bert-for-tf2


I wanted to further pretrain albert_base on an MLM task, but I realized there is no pretrained ckpt file provided for albert_base. So my plan was to convert the saved_model (or the model loaded from TF Hub) to a checkpoint myself, and then pretrain albert_base using the code provided at https://github.com/google-research/ALBERT/blob/master/run_pretraining.py.

Before further pretraining, to check whether the conversion to ckpt was successful, I re-converted the ckpt file back to saved_model format and loaded it as a Keras layer using bert-for-tf2 (https://github.com/kpe/bert-for-tf2/tree/master/bert). However, when I loaded the re-converted albert_base, its embeddings were different from those of the model loaded from the original albert_base.

Here is how I converted the original saved_model to a ckpt, and then back to a saved_model (I used tf version 1.15.0 on Colab):

"""
Convert tf-hub module to checkpoint files.
"""
albert_module = hub.Module(
    "https://tfhub.dev/google/albert_base/2",
    trainable=True)
saver = tf.train.Saver()
sess = tf.Session()
sess.run(tf.global_variables_initializer())
saver.save(sess, './albert/model_ckpt/albert_base')
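
Listing what actually ends up in that checkpoint makes the issue visible; a quick check along these lines (an inspection sketch, not part of my original run) should print names prefixed with 'module/':

import tensorflow as tf

# Inspection sketch: print the variable names stored in the checkpoint above.
# With hub.Module they come out as 'module/bert/...', not 'bert/...'.
for name, shape in tf.train.list_variables('./albert/model_ckpt/albert_base'):
    print(name, shape)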

"""
Save model loaded from ckpt in saved_model format.
"""
from tensorflow.python.saved_model import tag_constants

graph = tf.Graph()
with tf.Session(graph=graph) as sess:
    # Restore from checkpoint
    loader = tf.train.import_meta_graph('./albert/model_ckpt/albert_base.meta')
    loader.restore(sess, tf.train.latest_checkpoint('./albert/model_ckpt/'))

    # Export checkpoint to SavedModel
    builder = tf.saved_model.builder.SavedModelBuilder('./albert/saved_model')
    builder.add_meta_graph_and_variables(sess,
                                         [],
                                         strip_default_attrs=True)
    builder.save()    

Using bert-for-tf2, I load albert_base as a Keras layer and build a simple model:

import bert
from tensorflow.keras import Model
from tensorflow.keras.layers import Input, AveragePooling1D, Flatten

def load_pretrained_albert():
    model_name = "albert_base"
    model_dir = bert.fetch_tfhub_albert_model(model_name, ".models")
    model_params = bert.albert_params(model_name)
    l_bert = bert.BertModelLayer.from_params(model_params, name="albert")

    # use in a Keras Model here, and call model.build()
    max_seq_len = 128
    l_input_ids = Input(shape=(max_seq_len,), dtype='int32', name="l_input_ids")

    output = l_bert(l_input_ids)          # output: [batch_size, max_seq_len, hidden_size]

    # average over the sequence dimension to get one vector per example
    pooled_output = AveragePooling1D(pool_size=max_seq_len, data_format="channels_last")(output)
    pooled_output = Flatten()(pooled_output)

    model = Model(inputs=[l_input_ids], outputs=[pooled_output])
    model.build(input_shape=(None, max_seq_len))

    # load the pretrained ALBERT weights from the fetched model directory
    bert.load_albert_weights(l_bert, model_dir)

    return model

The code above loads weights from the saved_model. The problem is that when I overwrite the original saved_model of albert_base with the one I re-converted from the checkpoint, the resulting embeddings differ.
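
Concretely, the comparison amounts to something like the following sketch, where model_original and model_reconverted are hypothetical names for the model built by load_pretrained_albert() against each saved_model:

import numpy as np

# Feed identical token ids to both models and compare the pooled outputs.
# (model_original / model_reconverted are illustrative names, not variables defined above.)
input_ids = np.random.randint(0, 30000, size=(1, 128), dtype=np.int32)
emb_original = model_original.predict(input_ids)
emb_reconverted = model_reconverted.predict(input_ids)
print(np.allclose(emb_original, emb_reconverted))  # comes out False with the re-converted saved_model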

When I run the code above with the re-converted saved_model, the following warnings come up:

model = load_pretrained_albert()
Fetching ALBERT model: albert_base version: 2
Already  fetched:  albert_base.tar.gz
already unpacked at: .models\albert_base
loader: No value for:[albert_4/embeddings/word_embeddings/embeddings:0], i.e.:[bert/embeddings/word_embeddings] in:[.models\albert_base]
loader: No value for:[albert_4/embeddings/word_embeddings_projector/projector:0], i.e.:[bert/encoder/embedding_hidden_mapping_in/kernel] in:[.models\albert_base]
loader: No value for:[albert_4/embeddings/word_embeddings_projector/bias:0], i.e.:[bert/encoder/embedding_hidden_mapping_in/bias] in:[.models\albert_base]
loader: No value for:[albert_4/embeddings/position_embeddings/embeddings:0], i.e.:[bert/embeddings/position_embeddings] in:[.models\albert_base]
loader: No value for:[albert_4/embeddings/LayerNorm/gamma:0], i.e.:[bert/embeddings/LayerNorm/gamma] in:[.models\albert_base]
loader: No value for:[albert_4/embeddings/LayerNorm/beta:0], i.e.:[bert/embeddings/LayerNorm/beta] in:[.models\albert_base]
loader: No value for:[albert_4/encoder/layer_shared/attention/self/query/kernel:0], i.e.:[bert/encoder/transformer/group_0/inner_group_0/attention_1/self/query/kernel] in:[.models\albert_base]
loader: No value for:[albert_4/encoder/layer_shared/attention/self/query/bias:0], i.e.:[bert/encoder/transformer/group_0/inner_group_0/attention_1/self/query/bias] in:[.models\albert_base]
loader: No value for:[albert_4/encoder/layer_shared/attention/self/key/kernel:0], i.e.:[bert/encoder/transformer/group_0/inner_group_0/attention_1/self/key/kernel] in:[.models\albert_base]
loader: No value for:[albert_4/encoder/layer_shared/attention/self/key/bias:0], i.e.:[bert/encoder/transformer/group_0/inner_group_0/attention_1/self/key/bias] in:[.models\albert_base]
loader: No value for:[albert_4/encoder/layer_shared/attention/self/value/kernel:0], i.e.:[bert/encoder/transformer/group_0/inner_group_0/attention_1/self/value/kernel] in:[.models\albert_base]
loader: No value for:[albert_4/encoder/layer_shared/attention/self/value/bias:0], i.e.:[bert/encoder/transformer/group_0/inner_group_0/attention_1/self/value/bias] in:[.models\albert_base]
loader: No value for:[albert_4/encoder/layer_shared/attention/output/dense/kernel:0], i.e.:[bert/encoder/transformer/group_0/inner_group_0/attention_1/output/dense/kernel] in:[.models\albert_base]
loader: No value for:[albert_4/encoder/layer_shared/attention/output/dense/bias:0], i.e.:[bert/encoder/transformer/group_0/inner_group_0/attention_1/output/dense/bias] in:[.models\albert_base]
loader: No value for:[albert_4/encoder/layer_shared/attention/output/LayerNorm/gamma:0], i.e.:[bert/encoder/transformer/group_0/inner_group_0/LayerNorm/gamma] in:[.models\albert_base]
loader: No value for:[albert_4/encoder/layer_shared/attention/output/LayerNorm/beta:0], i.e.:[bert/encoder/transformer/group_0/inner_group_0/LayerNorm/beta] in:[.models\albert_base]
loader: No value for:[albert_4/encoder/layer_shared/intermediate/kernel:0], i.e.:[bert/encoder/transformer/group_0/inner_group_0/ffn_1/intermediate/dense/kernel] in:[.models\albert_base]
loader: No value for:[albert_4/encoder/layer_shared/intermediate/bias:0], i.e.:[bert/encoder/transformer/group_0/inner_group_0/ffn_1/intermediate/dense/bias] in:[.models\albert_base]
loader: No value for:[albert_4/encoder/layer_shared/output/dense/kernel:0], i.e.:[bert/encoder/transformer/group_0/inner_group_0/ffn_1/intermediate/output/dense/kernel] in:[.models\albert_base]
loader: No value for:[albert_4/encoder/layer_shared/output/dense/bias:0], i.e.:[bert/encoder/transformer/group_0/inner_group_0/ffn_1/intermediate/output/dense/bias] in:[.models\albert_base]
loader: No value for:[albert_4/encoder/layer_shared/output/LayerNorm/gamma:0], i.e.:[bert/encoder/transformer/group_0/inner_group_0/LayerNorm_1/gamma] in:[.models\albert_base]
loader: No value for:[albert_4/encoder/layer_shared/output/LayerNorm/beta:0], i.e.:[bert/encoder/transformer/group_0/inner_group_0/LayerNorm_1/beta] in:[.models\albert_base]
Done loading 0 BERT weights from: .models\albert_base into <bert.model.BertModelLayer object at 0x0000029687449D68> (prefix:albert_4). Count of weights not found in the checkpoint was: [22]. Count of weights with mismatched shape: [0]
Unused weights from saved model:
        module/bert/embeddings/LayerNorm/beta
        module/bert/embeddings/LayerNorm/gamma
        module/bert/embeddings/position_embeddings
        module/bert/embeddings/token_type_embeddings
        module/bert/embeddings/word_embeddings
        module/bert/encoder/embedding_hidden_mapping_in/bias
        module/bert/encoder/embedding_hidden_mapping_in/kernel
        module/bert/encoder/transformer/group_0/inner_group_0/LayerNorm/beta
        module/bert/encoder/transformer/group_0/inner_group_0/LayerNorm/gamma
        module/bert/encoder/transformer/group_0/inner_group_0/LayerNorm_1/beta
        module/bert/encoder/transformer/group_0/inner_group_0/LayerNorm_1/gamma
        module/bert/encoder/transformer/group_0/inner_group_0/attention_1/output/dense/bias
        module/bert/encoder/transformer/group_0/inner_group_0/attention_1/output/dense/kernel
        module/bert/encoder/transformer/group_0/inner_group_0/attention_1/self/key/bias
        module/bert/encoder/transformer/group_0/inner_group_0/attention_1/self/key/kernel
        module/bert/encoder/transformer/group_0/inner_group_0/attention_1/self/query/bias
        module/bert/encoder/transformer/group_0/inner_group_0/attention_1/self/query/kernel
        module/bert/encoder/transformer/group_0/inner_group_0/attention_1/self/value/bias
        module/bert/encoder/transformer/group_0/inner_group_0/attention_1/self/value/kernel
        module/bert/encoder/transformer/group_0/inner_group_0/ffn_1/intermediate/dense/bias
        module/bert/encoder/transformer/group_0/inner_group_0/ffn_1/intermediate/dense/kernel
        module/bert/encoder/transformer/group_0/inner_group_0/ffn_1/intermediate/output/dense/bias
        module/bert/encoder/transformer/group_0/inner_group_0/ffn_1/intermediate/output/dense/kernel
        module/bert/pooler/dense/bias
        module/bert/pooler/dense/kernel
        module/cls/predictions/output_bias
        module/cls/predictions/transform/LayerNorm/beta
        module/cls/predictions/transform/LayerNorm/gamma
        module/cls/predictions/transform/dense/bias
        module/cls/predictions/transform/dense/kernel

whereas when run with the original albert_base, the warnings are just the following:

model = load_pretrained_albert()
Fetching ALBERT model: albert_base version: 2
Already  fetched:  albert_base.tar.gz
already unpacked at: .models\albert_base
Done loading 22 BERT weights from: .models\albert_base into <bert.model.BertModelLayer object at 0x0000029680196320> (prefix:albert_5). Count of weights not found in the checkpoint was: [0]. Count of weights with mismatched shape: [0]
Unused weights from saved model:
        bert/embeddings/token_type_embeddings
        bert/pooler/dense/bias
        bert/pooler/dense/kernel
        cls/predictions/output_bias
        cls/predictions/transform/LayerNorm/beta
        cls/predictions/transform/LayerNorm/gamma
        cls/predictions/transform/dense/bias
        cls/predictions/transform/dense/kernel

Based on my understanding, the weights are not loaded properly because of the difference in variable names. Is there a way I can specify the names to save under when writing the ckpt? I feel like if, for instance, the weight 'module/bert/embeddings/LayerNorm/beta' were instead saved as 'bert/embeddings/LayerNorm/beta', the problem would be solved. How can I get rid of the 'module/' prefix?
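
(One option I have been wondering about, sketched below and not verified: pass tf.train.Saver an explicit {checkpoint_name: variable} dict when saving, so the checkpoint names are chosen there instead of taken from the graph, where hub.Module prefixes everything with 'module/'.)

import tensorflow as tf
import tensorflow_hub as hub

# Unverified sketch: save the hub.Module variables under names with 'module/' stripped.
albert_module = hub.Module(
    "https://tfhub.dev/google/albert_base/2",
    trainable=True)

name_to_var = {
    var.name.split(":")[0].replace("module/", "", 1): var
    for var in tf.global_variables()
}
saver = tf.train.Saver(name_to_var)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.save(sess, './albert/model_ckpt/albert_base')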

I feel like I might have made the problem sound more complicated than it is, but I tried to explain the situation I am in as specifically as I could, just in case :)


Solution

  • Problem solved! The problem WAS actually the difference in tensor names. I changed the names of the tensors in the checkpoint using the following code (https://gist.github.com/batzner/7c24802dd9c5e15870b4b56e22135c96).

    You just need to change 'module/bert/....' to 'bert/....' and it's all good.
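
    A minimal sketch of that renaming step (adapted from the idea in the linked gist; the function name and output path here are illustrative):

    import tensorflow as tf

    def strip_prefix_from_checkpoint(in_ckpt, out_ckpt, prefix="module/"):
        """Re-save every variable from in_ckpt under a name without the prefix."""
        with tf.Graph().as_default(), tf.Session() as sess:
            new_vars = []
            for name, _ in tf.train.list_variables(in_ckpt):
                value = tf.train.load_variable(in_ckpt, name)
                new_name = name[len(prefix):] if name.startswith(prefix) else name
                # Re-create the variable under the stripped name, initialized with the old value
                new_vars.append(tf.Variable(value, name=new_name))
            saver = tf.train.Saver(new_vars)
            sess.run(tf.global_variables_initializer())
            saver.save(sess, out_ckpt)

    strip_prefix_from_checkpoint('./albert/model_ckpt/albert_base',
                                 './albert/model_ckpt/albert_base_renamed')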