tensorflow, tensorflow-serving

How to grab the input and output tensors after loading a tf.saved_model in Python


Suppose I saved a model with the following code:

    tf.saved_model.simple_save(sess, export_dir,
                               inputs={'input_x': x, 'input_y': y},
                               outputs={'output_z': z})

Now I load the saved model back in another Python program:

    with tf.Session() as sess:
        tf.saved_model.loader.load(sess, ['serve'], export_dir)

The question is: how can I get handles to the x, y and z tensors using the 'input_x', 'input_y' and 'output_z' keys I passed in the inputs/outputs arguments when calling simple_save()?

The only solution I found online relies on explicitly naming the x, y and z tensors when creating them, and then using those names to retrieve them from the graph. That seems redundant, since we already specified keys for them when calling simple_save().
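For context, the keys given to simple_save() are not lost: simple_save() stores them in a SignatureDef under the 'serving_default' key, mapping each key to a TensorInfo that records the tensor's graph name. The sketch below builds that same kind of predict signature in memory, without saving anything. It assumes the TF 1.x-style API exposed as `tf.compat.v1` (TF 1.13+ or 2.x). The toy graph `z = x + y` and its deliberately unnamed tensors are illustrative only.

```python
import tensorflow as tf

tf1 = tf.compat.v1  # TF 1.x-style API (assumption: TF >= 1.13 or 2.x)

graph = tf.Graph()
with graph.as_default():
    # Unnamed tensors: TF assigns default names such as "Placeholder:0".
    x = tf1.placeholder(tf.float32, shape=[None])
    y = tf1.placeholder(tf.float32, shape=[None])
    z = x + y

    # simple_save() builds a predict signature like this from its
    # inputs/outputs dicts (must run in graph mode, hence inside as_default).
    sig = tf1.saved_model.signature_def_utils.predict_signature_def(
        inputs={'input_x': x, 'input_y': y},
        outputs={'output_z': z})

# Each key maps to a TensorInfo whose .name is the tensor's graph name.
for key, info in sig.inputs.items():
    print(key, '->', info.name)
print('output_z ->', sig.outputs['output_z'].name)

# The signature key simple_save() uses for this SignatureDef.
print(tf1.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY)
```

So the user-supplied keys survive in the SavedModel, and the generated tensor names can be looked up through them instead of being chosen by hand.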


Solution

  • I had exactly your problem and, after some investigation (the TF documentation is poor here, in my opinion), I found the following solution:

    Use the MetaGraphDef object returned by loader.load() to find your input/output name mapping.

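        import tensorflow as tf

        graph = tf.Graph()
        with graph.as_default():
            # Keep the session open so the loaded tensors can be run later
            sess = tf.Session(graph=graph)
            metagraph = tf.saved_model.loader.load(
                sess, [tf.saved_model.tag_constants.SERVING], save_path)

        # simple_save() registers its inputs/outputs under the 'serving_default' signature
        inputs_mapping = dict(metagraph.signature_def['serving_default'].inputs)
        outputs_mapping = dict(metagraph.signature_def['serving_default'].outputs)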
    

    This code gives you a mapping from the keys you supplied when saving to "TensorInfo" objects, and from each TensorInfo you can easily get the name of the mapped tensor, for example:

        my_input = inputs_mapping['my_input_name'].name
        my_input_t = graph.get_tensor_by_name(my_input)
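Putting the pieces together, here is an end-to-end sketch: save a toy model with simple_save() without naming any tensor, reload it, resolve x, y and z purely through the 'input_x', 'input_y' and 'output_z' signature keys, and run it. It assumes the TF 1.x-style API via `tf.compat.v1` (TF 1.13+ or 2.x); the model `z = w * x + y`, the helper names `save_toy_model` and `load_and_run`, and the temporary export path are hypothetical, chosen only for illustration.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1  # TF 1.x-style API (assumption: TF >= 1.13 or 2.x)


def save_toy_model(export_dir):
    """Hypothetical helper: save z = w * x + y without naming any tensor."""
    graph = tf.Graph()
    with graph.as_default():
        x = tf1.placeholder(tf.float32, shape=[None])
        y = tf1.placeholder(tf.float32, shape=[None])
        w = tf1.Variable(2.0)
        z = w * x + y
        with tf1.Session(graph=graph) as sess:
            sess.run(tf1.global_variables_initializer())
            tf1.saved_model.simple_save(
                sess, export_dir,
                inputs={'input_x': x, 'input_y': y},
                outputs={'output_z': z})


def load_and_run(export_dir, x_val, y_val):
    """Hypothetical helper: reload and resolve tensors via signature keys."""
    graph = tf.Graph()
    with graph.as_default():
        with tf1.Session(graph=graph) as sess:
            metagraph = tf1.saved_model.loader.load(
                sess, [tf1.saved_model.tag_constants.SERVING], export_dir)
            sig = metagraph.signature_def[
                tf1.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
            # TensorInfo.name -> tensor handle in the loaded graph
            x = graph.get_tensor_by_name(sig.inputs['input_x'].name)
            y = graph.get_tensor_by_name(sig.inputs['input_y'].name)
            z = graph.get_tensor_by_name(sig.outputs['output_z'].name)
            return sess.run(z, feed_dict={x: x_val, y: y_val})


# simple_save() requires that the export directory does not exist yet.
export_dir = os.path.join(tempfile.mkdtemp(), 'toy_model')
save_toy_model(export_dir)
result = load_and_run(export_dir, [1.0, 2.0], [10.0, 20.0])
print(result)  # [12. 24.]
```

As an alternative to going through the tensor name, `tf.compat.v1.saved_model.utils.get_tensor_from_tensor_info(sig.inputs['input_x'], graph)` returns the tensor directly from its TensorInfo.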