Suppose I saved a model with the following code:
```python
tf.saved_model.simple_save(sess, export_dir,
                           inputs={'input_x': x, 'input_y': y},
                           outputs={'output_z': z})
```
Now I load the saved model back in another Python program:
```python
with tf.Session() as sess:
    tf.saved_model.loader.load(sess, ['serve'], export_dir)
```
The question is: how can I get handles to the x, y and z tensors using the 'input_x', 'input_y' and 'output_z' keys I specified in the inputs/outputs arguments when calling simple_save()?
The only solution I found online relies on explicitly naming the x, y and z tensors when creating them and then using those names to retrieve them from the graph, which seems redundant since we already specified keys for them when calling simple_save().
I had exactly your problem, and after some investigation (poor TF documentation, in my opinion) I found the following solution:
Use the MetaGraphDef object returned by loader.load() to find the mapping from your input/output keys to tensor names:
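```python
import tensorflow as tf

graph = tf.Graph()
with graph.as_default(), tf.Session(graph=graph) as sess:
    # Load the MetaGraph tagged 'serve', which is what simple_save() writes
    metagraph = tf.saved_model.loader.load(
        sess, [tf.saved_model.tag_constants.SERVING], save_path)
    signature = metagraph.signature_def['serving_default']
    inputs_mapping = dict(signature.inputs)
    outputs_mapping = dict(signature.outputs)
```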
This gives you a mapping from the keys you supplied when saving to "TensorInfo" objects, and from each of those you can easily get the name of the mapped tensor, for example:
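```python
# TensorInfo.name is the tensor's name in the loaded graph (form 'op_name:0')
my_input = inputs_mapping['my_input_name'].name  # e.g. the key 'input_x'
my_input_t = graph.get_tensor_by_name(my_input)
```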
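Putting the pieces together, here is a minimal end-to-end sketch, assuming TensorFlow 1.13+ or 2.x (it uses the tf.compat.v1 graph-mode API so it runs on either). The helper load_with_signature and the toy model (x + y) * w are hypothetical, made up for illustration; the literals 'serve' and 'serving_default' are the tag and signature key that simple_save() writes. The model is saved with the question's keys without naming any tensor, then the tensors are looked up purely by those keys.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1  # graph-mode API, available in TF 1.13+ and TF 2.x


def load_with_signature(sess, export_dir, signature_key='serving_default'):
    """Hypothetical helper: load a SavedModel into sess.graph and map the
    signature keys given to simple_save() to tf.Tensor handles."""
    # simple_save() stores its MetaGraph under the 'serve' tag.
    metagraph = tf1.saved_model.loader.load(sess, ['serve'], export_dir)
    signature = metagraph.signature_def[signature_key]
    graph = sess.graph
    # Each TensorInfo.name is a tensor name in the loaded graph ('op_name:0').
    inputs = {key: graph.get_tensor_by_name(info.name)
              for key, info in signature.inputs.items()}
    outputs = {key: graph.get_tensor_by_name(info.name)
               for key, info in signature.outputs.items()}
    return inputs, outputs


# Save a toy model with the same keys as the question; note that
# x, y and z are never given explicit names.
export_dir = os.path.join(tempfile.mkdtemp(), 'model')
with tf.Graph().as_default() as g, tf1.Session(graph=g) as sess:
    x = tf1.placeholder(tf.float32, shape=[None])
    y = tf1.placeholder(tf.float32, shape=[None])
    w = tf1.Variable(2.0)
    z = (x + y) * w
    sess.run(tf1.global_variables_initializer())
    tf1.saved_model.simple_save(sess, export_dir,
                                inputs={'input_x': x, 'input_y': y},
                                outputs={'output_z': z})

# Load it into a fresh graph and fetch tensors by signature key only.
with tf.Graph().as_default() as g, tf1.Session(graph=g) as sess:
    inputs, outputs = load_with_signature(sess, export_dir)
    result = sess.run(outputs['output_z'],
                      feed_dict={inputs['input_x']: [1.0, 2.0],
                                 inputs['input_y']: [3.0, 4.0]})

print(result)  # (1 + 3) * 2 and (2 + 4) * 2, i.e. 8.0 and 12.0
```

Because the lookup goes through the signature rather than hard-coded op names, the loading code keeps working if the saving side builds its ops differently, as long as the keys stay the same.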