To avoid carrying optimizer and gradient nodes into the inference environment, I'm trying to create two versions of the graph: one with the training nodes and one without.
The idea was to use `tensorflow.train.Saver` to pass variables from the training version of the graph into the inference version.
So I tried the following:
```python
# Create training graph
trainingGraph = tf.Graph()
with (trainingGraph.as_default()):
    trainOp, lossOp = self.CreateTrainingGraph()
    trainInitOp = tf.initialize_variables(tf.all_variables(), "init_variables")
    # Add saver op
    self.saverOp = tf.train.Saver()

# Create inference graph
inferenceGraph = tf.Graph()
with (inferenceGraph.as_default()):
    self.CreateInferenceGraph()
    # Add saver op, compatible with training saver
    tf.train.Saver(saver_def=self.saverOp.as_saver_def())
```
Here `CreateTrainingGraph()` calls `CreateInferenceGraph()` and adds the optimizer and loss on top of it.

For some reason, the `tf.train.Saver` constructor doesn't add a `save/restore_all` node to the inference graph (or I just don't understand what the `saver_def` option does). I've also tried the empty constructor, and running

```python
sess.run([model.saverOp._restore_op_name],
         {model.saverOp._filename_tensor_name: "Params/data.pb"})
```

failed with the error

```
<built-in function delete_Status> returned a result with an error set
```

What is the proper way to achieve this?
When you construct your inference graph, you should be able to construct a `tf.train.Saver()` with no arguments, and it will construct the appropriate save and restore ops for you. You should then be able to call `saver.restore(sess, filename)` to restore the variables from a file.
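A minimal end-to-end sketch of this approach follows. It assumes TensorFlow 2.x used through the `tf.compat.v1` graph-mode API (so it calls `tf.global_variables()` where the question uses the older `tf.all_variables()`), and `create_inference_graph()` and `create_training_graph()` are hypothetical stand-ins for the question's `CreateInferenceGraph()` and `CreateTrainingGraph()`. The training graph gains extra Adam slot variables, while the inference graph's argument-free `tf.train.Saver()` only covers the two model variables.

```python
import os
import tempfile

import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()  # build explicit graphs, as in the question


def create_inference_graph():
    """Hypothetical model: y = x * w + b (forward pass only)."""
    x = tf.placeholder(tf.float32, shape=[None, 1], name="x")
    w = tf.get_variable("w", shape=[1, 1], initializer=tf.ones_initializer())
    b = tf.get_variable("b", shape=[1], initializer=tf.zeros_initializer())
    y = tf.add(tf.matmul(x, w), b, name="y")
    return x, y


def create_training_graph():
    """Reuses the inference graph and adds a loss and an optimizer on top."""
    x, y = create_inference_graph()
    target = tf.placeholder(tf.float32, shape=[None, 1], name="target")
    loss = tf.reduce_mean(tf.square(y - target))
    # Adam creates slot variables that exist only in the training graph.
    train_op = tf.train.AdamOptimizer(0.1).minimize(loss)
    return x, y, target, train_op


ckpt_path = os.path.join(tempfile.mkdtemp(), "model.ckpt")
feed_x = [[3.0]]

# 1) Train, then save with a Saver built inside the training graph.
training_graph = tf.Graph()
with training_graph.as_default():
    x, y, target, train_op = create_training_graph()
    init_op = tf.global_variables_initializer()
    train_saver = tf.train.Saver()
    training_var_names = sorted(v.op.name for v in tf.global_variables())

with tf.Session(graph=training_graph) as sess:
    sess.run(init_op)
    for _ in range(100):
        sess.run(train_op, {x: [[1.0], [2.0]], target: [[3.0], [5.0]]})
    trained_prediction = sess.run(y, {x: feed_x})
    train_saver.save(sess, ckpt_path)

# 2) Build the inference graph and a Saver with no arguments.
inference_graph = tf.Graph()
with inference_graph.as_default():
    x, y = create_inference_graph()
    inference_saver = tf.train.Saver()  # covers only this graph's variables
    inference_var_names = sorted(v.op.name for v in tf.global_variables())

# 3) Restore from the training checkpoint and run inference.
with tf.Session(graph=inference_graph) as sess:
    inference_saver.restore(sess, ckpt_path)
    restored_prediction = sess.run(y, {x: feed_x})
```

Because `restore` assigns every variable the inference `Saver` covers, the inference session needs no initializer op, and the Adam slot variables stored in the checkpoint are simply ignored.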
N.B. For the constructor to work with no arguments, (i) the variables in the inference graph (i.e. the result of `tf.all_variables()`) must be a subset of the variables in the training graph, and (ii) the corresponding variables must have exactly the same names. If either of these conditions doesn't hold, you will need to pass the saver constructor a dictionary that maps the names stored in the checkpoint to the variables in the inference graph. (However, if `self.CreateTrainingGraph()` calls `self.CreateInferenceGraph()` before creating any other variables, and doesn't do anything different with `tf.name_scope()`, then this should be fine.)
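If the names do not line up, the dictionary form of `tf.train.Saver`'s `var_list` argument handles it. The sketch below again assumes the `tf.compat.v1` API; the `inference` variable scope is a hypothetical example of an inference graph that places its variables under a different prefix than the checkpoint, which stored them as plain `w` and `b`.

```python
import os
import tempfile

import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

ckpt_path = os.path.join(tempfile.mkdtemp(), "model.ckpt")

# Checkpoint written from a graph whose variables are named "w" and "b".
train_graph = tf.Graph()
with train_graph.as_default():
    w = tf.get_variable("w", initializer=tf.constant([[2.0]]))
    b = tf.get_variable("b", initializer=tf.constant([1.0]))
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver.save(sess, ckpt_path)

# Hypothetical inference graph whose variables live under "inference/".
infer_graph = tf.Graph()
with infer_graph.as_default():
    with tf.variable_scope("inference"):
        w = tf.get_variable("w", shape=[1, 1])
        b = tf.get_variable("b", shape=[1])
    # Keys are the names stored in the checkpoint; values are this graph's variables.
    prefix = "inference/"
    var_map = {v.op.name[len(prefix):]: v for v in tf.global_variables()}
    restore_saver = tf.train.Saver(var_list=var_map)
    with tf.Session() as sess:
        restore_saver.restore(sess, ckpt_path)
        w_val, b_val = sess.run([w, b])
```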
(The `saver_def` argument is meant for the less common case where you load in a graph, for example using `tf.import_graph_def()`, that already contains the save and restore ops from a previously created `Saver`. The constructor will then create a `Saver` in your Python program that reuses those ops instead of building new ones, and you will get a mysterious error if the graph does not contain those ops.)