Search code examples
Tags: python, tensorflow, keras, tensorflow-hub

tensorflow_hub raises NotImplementedError when saving a Keras model


I tried to save a Keras model by following the TensorFlow tutorial:

# Minimal reproduction: saving to HDF5 fails because hub.KerasLayer is built
# from an already-loaded object rather than from a string handle.
from tensorflow.keras.layers import Dense, Input
from tensorflow.keras.models import Model
import tensorflow_hub as hub
import tensorflow as tf

module_url = "https://tfhub.dev/google/universal-sentence-encoder/4"

# Scalar string input: one string per example.
input1 = Input(shape=[], dtype=tf.string)
# hub.load returns the restored module object, not a string.
loaded_obj = hub.load(module_url)
# Passing that object (instead of module_url) as the handle is what later makes
# KerasLayer.get_config() raise NotImplementedError (see the traceback below).
emb = hub.KerasLayer(loaded_obj, trainable=False)
embedding_layer = emb(input1)
dense1 = Dense(units=512, activation="relu")(embedding_layer)
outputs = Dense(1, activation="sigmoid")(dense1)

model = Model(inputs=input1, outputs=outputs)
model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["AUC"])

# Saves the raw hub module (not the Keras model); this call succeeds.
tf.saved_model.save(loaded_obj, "fine_tuned")
# Fails: the .h5 format needs the model config, which hub.KerasLayer can only
# produce when its handle is a string.
model.save("model.h5", include_optimizer=False) 

The last line raises:

NotImplementedError                       Traceback (most recent call last) /var/folders/x9/2_wr3dnn4pv0v_t3k096rrt00000gn/T/ipykernel_49946/3843995216.py in <module>
     17 
     18 tf.saved_model.save(loaded_obj, "fine_tuned")
---> 19 model.save("model.h5", include_optimizer=False)

~/anaconda3/envs/tensorflow/lib/python3.7/site-packages/keras/utils/traceback_utils.py in error_handler(*args, **kwargs)
     65     except Exception as e:  # pylint: disable=broad-except
     66       filtered_tb = _process_traceback_frames(e.__traceback__)
---> 67       raise e.with_traceback(filtered_tb) from None
     68     finally:
     69       del filtered_tb

~/anaconda3/envs/tensorflow/lib/python3.7/site-packages/tensorflow_hub/keras_layer.py in get_config(self)
    330           "Can only generate a valid config for `hub.KerasLayer(handle, ...)`"
    331           "that uses a string `handle`.\n\n"
--> 332           "Got `type(handle)`: {}".format(type(self._handle)))
    333     config["handle"] = self._handle
    334 

NotImplementedError: Can only generate a valid config for `hub.KerasLayer(handle, ...)`that uses a string `handle`.

Got `type(handle)`: <class 'tensorflow.python.saved_model.load.Loader._recreate_base_user_object.<locals>._UserObject'>

How can I fix this? `model.to_json()` also raises the same `NotImplementedError`.

# Report the library versions relevant to the error above.
print(f"tensorflow: {tf.__version__}")
print(f"tensorflow_hub: {hub.__version__}")
print(f"keras: {tf.keras.__version__}")

tensorflow: 2.7.0
tensorflow_hub: 0.12.0
keras: 2.7.0

Solution

  • According to this post:

    hub.KerasLayer cannot save a Keras model config (as required for saving to HDF5) if initialized with a Python callable instead of a string [...]

    So either pass the module URL (a string handle) to `hub.KerasLayer` instead of the loaded object:

    # Fix 1: give hub.KerasLayer a string handle so the layer (and therefore the
    # whole model) can produce a config, which saving to HDF5 requires.
    from tensorflow.keras.layers import Dense, Input
    from tensorflow.keras.models import Model
    import tensorflow_hub as hub
    import tensorflow as tf

    module_url = "https://tfhub.dev/google/universal-sentence-encoder/4"

    input1 = Input(shape=[], dtype=tf.string)
    # Separate load of the same module; it is only used for the raw-module
    # save below and is not part of the Keras model.
    loaded_obj = hub.load(module_url)
    # Reuse module_url rather than repeating the literal, so both loads always
    # refer to the same module. A string handle is what lets get_config() work.
    emb = hub.KerasLayer(module_url, trainable=False)
    embedding_layer = emb(input1)
    dense1 = Dense(units=512, activation="relu")(embedding_layer)
    outputs = Dense(1, activation="sigmoid")(dense1)

    model = Model(inputs=input1, outputs=outputs)
    model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["AUC"])

    # Saves the unmodified hub module: trainable=False and no training happens
    # here, so nothing is actually fine-tuned despite the directory name.
    tf.saved_model.save(loaded_obj, "fine_tuned")
    # Reloading this file typically needs
    # custom_objects={"KerasLayer": hub.KerasLayer} (see the TF Hub docs).
    model.save("model.h5", include_optimizer=False)
    

    Or keep passing the loaded object, and save the model in the SavedModel format instead of HDF5 (use a path without the `.h5` extension):

    # Fix 2: keep building hub.KerasLayer from the loaded object, but save in
    # the SavedModel format, which does not need a layer config like HDF5 does.
    # Imports repeated here so this snippet runs on its own.
    from tensorflow.keras.layers import Dense, Input
    from tensorflow.keras.models import Model
    import tensorflow_hub as hub
    import tensorflow as tf

    module_url = "https://tfhub.dev/google/universal-sentence-encoder/4"

    input1 = Input(shape=[], dtype=tf.string)
    loaded_obj = hub.load(module_url)
    # A non-string handle is fine here, because get_config() is never needed.
    emb = hub.KerasLayer(loaded_obj, trainable=False)
    embedding_layer = emb(input1)
    dense1 = Dense(units=512, activation="relu")(embedding_layer)
    outputs = Dense(1, activation="sigmoid")(dense1)

    model = Model(inputs=input1, outputs=outputs)
    model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["AUC"])

    # Saves the unmodified hub module (trainable=False, no training here).
    tf.saved_model.save(loaded_obj, "fine_tuned")
    # State the format explicitly instead of relying on the missing ".h5"
    # extension to select SavedModel.
    model.save("model", include_optimizer=False, save_format="tf")