Let's say someone hands me a TF SavedModel and I would like to replicate this model on the 4 GPUs I have on my machine so I can run inference in parallel on batches of data. Are there any good examples of how to do this?
I can load a saved model in this way:
import tensorflow as tf

def load_model(self, saved_model_dirpath):
    '''Loads a model from a saved model directory - this should
    contain a .pb file and a variables directory'''
    signature_key = tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
    input_key = 'input'
    output_key = 'output'
    # Load the graph and its variables into the existing session.
    meta_graph_def = tf.saved_model.loader.load(self.sess, [tf.saved_model.tag_constants.SERVING],
                                                saved_model_dirpath)
    signature = meta_graph_def.signature_def
    # Look up the tensor names recorded in the serving signature.
    input_tensor_name = signature[signature_key].inputs[input_key].name
    output_tensor_name = signature[signature_key].outputs[output_key].name
    # Keep handles to the input and output tensors for later sess.run() calls.
    self.input_tensor = self.sess.graph.get_tensor_by_name(input_tensor_name)
    self.output_tensor = self.sess.graph.get_tensor_by_name(output_tensor_name)
...but this would require that I have a handle to the session. For models that I have written myself, I would have access to the inference function and could just call it inside a with tf.device() block, but in this case I'm not sure how to extract the inference function from a SavedModel. Should I load 4 separate sessions, or is there a better way? I couldn't find much documentation on this, but apologies in advance if I missed something. Thanks!
There is no support for this use case in TensorFlow at the moment. Unfortunately, "replicating the inference function" based only on the SavedModel (which is basically the computation graph with some metadata) is a fairly complex (and brittle, if implemented) graph transformation problem.
If you don't have access to the source code that produced this model, your best bet is to load the SavedModel 4 times into 4 separate graphs, rewriting the target device to the corresponding GPU each time. Then, run each graph/session separately.
Note that you can invoke sess.run() multiple times concurrently, since sess.run() releases the GIL for the duration of the actual computation. All you need is several Python threads.
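As a rough illustration of the threading part, here is a minimal sketch that fans batches out over several replicas with a thread pool. It deliberately does not touch TensorFlow: run_on_replicas, make_fake_replica and the run_fns argument are hypothetical names introduced for this example, and each callable in run_fns stands in for something that would wrap sess.run() on one of the 4 separately loaded sessions. The fixed round-robin split of batches is also just one simple choice, not the only way to schedule the work.

```python
from concurrent.futures import ThreadPoolExecutor


def run_on_replicas(run_fns, batches):
    """Run every batch on one of the replicas, using one thread per replica.

    run_fns: one callable per replica (hypothetical; in the setup above each
        would wrap sess.run() on its own session).
    batches: list of input batches.
    Returns the outputs in the same order as `batches`.
    """
    num_replicas = len(run_fns)

    def worker(replica_idx):
        # Each worker thread drives one replica and takes every
        # num_replicas-th batch, so a replica handles one batch at a time.
        run_fn = run_fns[replica_idx]
        return [(i, run_fn(batches[i]))
                for i in range(replica_idx, len(batches), num_replicas)]

    outputs = [None] * len(batches)
    with ThreadPoolExecutor(max_workers=num_replicas) as pool:
        for chunk in pool.map(worker, range(num_replicas)):
            for i, out in chunk:
                outputs[i] = out  # restore the original batch order
    return outputs


def make_fake_replica(gpu_id):
    # Stand-in for demonstration only. A real replica would do something like
    # sess.run(output_tensor, feed_dict={input_tensor: batch}) on the session
    # whose graph was placed on GPU `gpu_id`.
    def run_fn(batch):
        return [x * 2 for x in batch]
    return run_fn


replicas = [make_fake_replica(i) for i in range(4)]
batches = [[i, i + 1] for i in range(10)]
results = run_on_replicas(replicas, batches)
```

With real sessions, the threads would spend most of their time inside sess.run(), where the GIL is released, so the 4 GPUs can work at the same time. If batches take very different amounts of time, a shared queue that idle workers pull from would likely balance the load better than this fixed split.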