Tags: python, tensorflow, keras, deep-learning, bert-language-model

How to save and load a custom Siamese BERT model


I am following this tutorial on how to train a Siamese BERT network:

https://keras.io/examples/nlp/semantic_similarity_with_bert/

All good, but I am not sure what the best way is to save the model after training it and then load it back. Any suggestions?

I tried saving it with:

model.save('models/bert_siamese_v1')

which creates a folder containing saved_model.pb, keras_metadata.pb, and two subfolders (variables and assets).

Then I tried to load it with:

model.load_weights('models/bert_siamese_v1/')

which gave me this error:

2022-03-08 14:11:52.567762: W tensorflow/core/util/tensor_slice_reader.cc:95] Could not open models/bert_siamese_v1/: Failed precondition: models/bert_siamese_v1; Is a directory: perhaps your file is in a different file format and you need to use a different restore operator?

What is the best way to proceed?


Solution

  • Try using tf.saved_model.save to save your model:

    import tensorflow as tf

    tf.saved_model.save(model, 'models/bert_siamese_v1')   # export the trained model in SavedModel format
    model = tf.saved_model.load('models/bert_siamese_v1')  # reload it later for inference
    

    The warning you get during saving can apparently be ignored. After loading your model, you can run inference through its serving signature, f(test_data):

    f = model.signatures["serving_default"]  # concrete serving function of the loaded model
    x1 = tf.random.uniform((1, 128), maxval=100, dtype=tf.int32)  # dummy attention masks
    x2 = tf.random.uniform((1, 128), maxval=100, dtype=tf.int32)  # dummy input ids
    x3 = tf.random.uniform((1, 128), maxval=100, dtype=tf.int32)  # dummy token type ids
    print(f)
    print(f(attention_masks=x1, input_ids=x2, token_type_ids=x3))
    
    ConcreteFunction signature_wrapper(*, token_type_ids, attention_masks, input_ids)
      Args:
        attention_masks: int32 Tensor, shape=(None, 128)
        input_ids: int32 Tensor, shape=(None, 128)
        token_type_ids: int32 Tensor, shape=(None, 128)
      Returns:
        {'dense': <1>}
          <1>: float32 Tensor, shape=(None, 3)
    {'dense': <tf.Tensor: shape=(1, 3), dtype=float32, numpy=array([[0.40711606, 0.13456087, 0.45832306]], dtype=float32)>}
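
  • Alternatively, you can keep the Keras training API available after restoring. tf.saved_model.load returns a low-level object that only exposes its signatures and variables, so a common pattern for custom models like this one is to save just the weights and rebuild the architecture in code before restoring them. A minimal sketch, assuming a hypothetical build_siamese_model() helper that re-runs the model-construction code from the tutorial:

    import tensorflow as tf

    # Rebuild exactly the same architecture as in the tutorial;
    # `build_siamese_model` is a placeholder for that construction code.
    model = build_siamese_model()

    # After training: store only the weights in checkpoint format.
    model.save_weights('models/bert_siamese_v1_weights/ckpt')

    # Later: rebuild the model and restore the weights into it.
    model = build_siamese_model()
    model.load_weights('models/bert_siamese_v1_weights/ckpt')

    The restored model keeps fit, evaluate and predict, which the object returned by tf.saved_model.load does not provide.

  • To run the loaded signature on real sentence pairs instead of random tensors, tokenize them the same way the tutorial does. A small sketch, assuming the bert-base-uncased tokenizer from Hugging Face transformers and the tutorial's maximum sequence length of 128:

    import tensorflow as tf
    from transformers import BertTokenizer

    model = tf.saved_model.load('models/bert_siamese_v1')
    f = model.signatures["serving_default"]
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

    # Encode one sentence pair to the fixed length the model expects.
    enc = tokenizer(
        "A soccer game with multiple males playing.",
        "Some men are playing a sport.",
        padding="max_length",
        max_length=128,
        truncation=True,
        return_tensors="tf",
    )

    # The keyword names match the model's input layers shown above.
    print(f(
        input_ids=enc["input_ids"],
        attention_masks=enc["attention_mask"],
        token_type_ids=enc["token_type_ids"],
    ))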