
How to save a seq2seq model in TensorFlow 2.x?


I'm following the "Neural machine translation with attention" tutorial from the TensorFlow docs, but I can't figure out how to save the model as a SavedModel file.

As shown in the docs, I can save a checkpoint fairly easily, but AFAIK that's not very useful when integrating with other applications. Does anyone know how to save the whole "model", even though the tutorial doesn't build it as a single tf.keras.Model?

Docs: https://www.tensorflow.org/tutorials/text/nmt_with_attention


Solution

  • As explained here, there are two saving mechanisms in TensorFlow: checkpoints and SavedModel.

    If the code that defines the model (here, the tutorial code) will always be available, you can simply restore the model from a checkpoint and use it.

    To get a SavedModel, you would need to rewrite the code as a class CustomModule(tf.Module) and be careful about the following:

    When you save a tf.Module, any tf.Variable attributes, tf.function-decorated methods, and tf.Modules found via recursive traversal are saved. (See the Checkpoint tutorial for more about this recursive traversal.) However, any Python attributes, functions, and data are lost. This means that when a tf.function is saved, no Python code is saved. If no Python code is saved, how does SavedModel know how to restore the function? Briefly, tf.function works by tracing the Python code to generate a ConcreteFunction (a callable wrapper around tf.Graph). When saving a tf.function, you're really saving the tf.function's cache of ConcreteFunctions. To learn more about the relationship between tf.function and ConcreteFunctions, see the tf.function guide.

    More information here
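For the checkpoint route, here is a minimal sketch of why the code must stay available. It assumes a hypothetical `TinyEncoder` tf.Module (an embedding table plus one projection, not the tutorial's GRU encoder). A checkpoint stores only variable values, so restoring means rebuilding the same Python object and loading the values into it:

```python
import os
import tempfile

import tensorflow as tf


class TinyEncoder(tf.Module):
    """Hypothetical stand-in for the tutorial's encoder (not the real one)."""

    def __init__(self, vocab_size=50, units=8):
        super().__init__()
        self.embedding = tf.Variable(tf.random.normal([vocab_size, units]), name="embedding")
        self.kernel = tf.Variable(tf.random.normal([units, units]), name="kernel")

    def __call__(self, tokens):
        # Mean-pool embedded tokens, then project: [batch, len] -> [batch, units].
        embedded = tf.nn.embedding_lookup(self.embedding, tokens)
        return tf.matmul(tf.reduce_mean(embedded, axis=1), self.kernel)


ckpt_dir = tempfile.mkdtemp()
encoder = TinyEncoder()
save_path = tf.train.Checkpoint(encoder=encoder).save(os.path.join(ckpt_dir, "ckpt"))

# Restoring needs the TinyEncoder class again: the checkpoint holds values only,
# so we recreate the object in Python and copy the saved values into it.
restored = TinyEncoder()
tf.train.Checkpoint(encoder=restored).restore(save_path)

tokens = tf.constant([[1, 2, 3]])
tf.debugging.assert_near(encoder(tokens), restored(tokens))
```

Without the `TinyEncoder` definition there is nothing to load the values into, which is why checkpoints alone are awkward for other applications.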
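For the SavedModel route, a minimal sketch of the tf.Module approach might look like the following. It assumes a hypothetical toy `Seq2SeqModule` (mean-pooled embedding encoder, tanh recurrence, greedy decoding, no attention) in place of the tutorial's model; the point is the structure. Variables are attributes, the whole inference path sits in one tf.function with an `input_signature` so a ConcreteFunction is traced at save time, and Python-only state (`start_id`, `max_len`) is baked into the traced graph as constants:

```python
import tempfile

import tensorflow as tf


class Seq2SeqModule(tf.Module):
    """Hypothetical toy encoder-decoder, NOT the tutorial's architecture."""

    def __init__(self, vocab_size=20, units=16, start_id=1, max_len=10):
        super().__init__()
        # tf.Variable attributes are found by traversal and saved.
        self.embedding = tf.Variable(tf.random.normal([vocab_size, units]), name="embedding")
        self.w_enc = tf.Variable(tf.random.normal([units, units]), name="w_enc")
        self.w_dec = tf.Variable(tf.random.normal([units, units]), name="w_dec")
        self.w_out = tf.Variable(tf.random.normal([units, vocab_size]), name="w_out")
        # Plain Python ints are NOT saved; they only live on as graph constants.
        self.start_id = start_id
        self.max_len = max_len

    @tf.function(input_signature=[tf.TensorSpec([None, None], tf.int32)])
    def translate(self, source_tokens):
        # Encoder: mean of the embedded source tokens -> initial decoder state.
        embedded = tf.nn.embedding_lookup(self.embedding, source_tokens)
        state = tf.tanh(tf.matmul(tf.reduce_mean(embedded, axis=1), self.w_enc))

        batch = tf.shape(source_tokens)[0]
        token = tf.fill([batch], self.start_id)
        outputs = tf.TensorArray(tf.int32, size=self.max_len)
        # Greedy decoding; AutoGraph converts this tf.range loop to a while loop.
        for i in tf.range(self.max_len):
            step_in = tf.nn.embedding_lookup(self.embedding, token)
            state = tf.tanh(tf.matmul(step_in, self.w_dec) + state)
            logits = tf.matmul(state, self.w_out)
            token = tf.argmax(logits, axis=-1, output_type=tf.int32)
            outputs = outputs.write(i, token)
        # stack() gives [max_len, batch]; transpose to [batch, max_len].
        return tf.transpose(outputs.stack())


export_dir = tempfile.mkdtemp()
module = Seq2SeqModule()
# Passing the tf.function as `signatures` also exports it as "serving_default",
# which tools outside Python (e.g. TF Serving, saved_model_cli) can use.
tf.saved_model.save(module, export_dir, signatures=module.translate)

# No Python class is needed to load: only variables and traced functions return.
loaded = tf.saved_model.load(export_dir)
source = tf.constant([[3, 4, 5], [6, 7, 8]], dtype=tf.int32)
print(loaded.translate(source))

serving_fn = loaded.signatures["serving_default"]
print(serving_fn(source_tokens=source)["output_0"])
```

Text preprocessing such as the tutorial's tokenization is ordinary Python, so by the rule quoted above it would not be saved and would have to run outside the SavedModel. Assigning the tutorial's trained encoder and decoder as attributes of a module like this should also be picked up by the recursive traversal, but I haven't verified that against the tutorial code.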