
Transfer parameters from training to inference graph


To avoid carrying the optimizer and gradient nodes into the inference environment, I'm trying to create two versions of the graph: one with the training nodes and one without.

The idea was to use tensorflow.train.Saver to pass the variables from the training version of the graph into the inference version.

So I've tried the following:

```python
# Create training graph
trainingGraph = tf.Graph()
with trainingGraph.as_default():
    trainOp, lossOp = self.CreateTrainingGraph()
    trainInitOp = tf.initialize_variables(tf.all_variables(), "init_variables")

    # Add saver op
    self.saverOp = tf.train.Saver()

# Create inference graph
inferenceGraph = tf.Graph()
with inferenceGraph.as_default():
    self.CreateInferenceGraph()

    # Add saver op, compatible with training saver
    tf.train.Saver(saver_def=self.saverOp.as_saver_def())
```

In this case, CreateTrainingGraph() calls CreateInferenceGraph() and adds the optimizer and loss on top of it.
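The structure described above can be sketched as follows. This is a minimal illustration under assumptions, not the original code: the function names `create_inference_graph`, `create_training_graph` and `variable_names`, the tiny linear model, and the choice of Adam are all hypothetical, and it assumes a TensorFlow version that provides the `tf.compat.v1` graph API (where `tf.all_variables()` has been replaced by `global_variables()`). Building each version in its own `tf.Graph` and listing its variables shows that the training graph contains the same `w` and `b` plus the optimizer's extra variables.

```python
import tensorflow as tf

tf1 = tf.compat.v1  # assumption: TF 1.x-style graph API via tf.compat.v1


def create_inference_graph():
    # Hypothetical stand-in for CreateInferenceGraph(): a tiny linear model.
    x = tf1.placeholder(tf.float32, shape=[None, 2], name="x")
    w = tf1.get_variable("w", shape=[2, 1], initializer=tf1.zeros_initializer())
    b = tf1.get_variable("b", shape=[1], initializer=tf1.zeros_initializer())
    y = tf.add(tf.matmul(x, w), b, name="y")
    return x, y


def create_training_graph():
    # Hypothetical stand-in for CreateTrainingGraph(): build the inference
    # part first, then add a loss and an optimizer on top of it.
    x, y = create_inference_graph()
    target = tf1.placeholder(tf.float32, shape=[None, 1], name="target")
    loss = tf.reduce_mean(tf.square(y - target), name="loss")
    train_op = tf1.train.AdamOptimizer(learning_rate=0.1).minimize(loss)
    return train_op, loss


def variable_names(build_fn):
    # Build the graph in a fresh tf.Graph and return its global variable names.
    graph = tf.Graph()
    with graph.as_default():
        build_fn()
        return {v.op.name for v in tf1.global_variables()}


training_names = variable_names(create_training_graph)
inference_names = variable_names(create_inference_graph)
print(sorted(inference_names))
print(sorted(training_names - inference_names))  # optimizer-only variables
```

Because the inference part is built first in both graphs, its variables get identical names, which is what lets a checkpoint from one graph be restored into the other.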

For some reason, the tf.train.Saver constructor doesn't add a save/restore_all node to the inference graph (or I just don't understand what the saver_def option does). I've tried the empty constructor, and

```python
sess.run([model.saverOp._restore_op_name],
         {model.saverOp._filename_tensor_name: "Params/data.pb"})
```

failed with the error:

```
<built-in function delete_Status> returned a result with an error set
```

What is the proper way to achieve this?


Solution

  • When you build your inference graph, you should be able to construct a tf.train.Saver() with no arguments, and it will create the appropriate save and restore ops for you. You should then be able to call saver.restore(sess, filename) to restore the variables from a file.

    N.B. For the no-argument constructor to work, (i) the variables in the inference graph (i.e. the result of tf.all_variables()) must be a subset of the variables in the training graph, and (ii) the corresponding variables must have exactly the same names. If either condition doesn't hold, you will need to pass the saver constructor a map from checkpoint names to variables (a dictionary passed as its var_list argument). (However, if self.CreateTrainingGraph() calls self.CreateInferenceGraph() before creating any other variables, and doesn't do anything different with tf.name_scope(), then this should be fine.)

    (The saver_def argument is used in the relatively rare case where you load a graph, for example using tf.import_graph_def(), that already contains the save and restore ops from a previously created Saver. The constructor then creates a Saver in your Python program that reuses those ops, and you will get a confusing error if the graph does not contain them.)
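The approach in the answer can be sketched end to end as follows. This is a minimal sketch under stated assumptions: it uses the `tf.compat.v1` graph/session API rather than the older `tf.train.Saver` spelling, and the tiny linear model, the helper names `create_inference_graph` and `create_training_graph`, and the temporary checkpoint path are hypothetical. The training graph's `Saver` writes every variable (including Adam's slot variables); the inference graph contains no optimizer or gradient ops, and its argument-less `Saver` only restores `w` and `b`.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1  # assumption: TF 1.x-style graph API via tf.compat.v1


def create_inference_graph():
    # Hypothetical model: variables "w" and "b" and an output tensor "y".
    x = tf1.placeholder(tf.float32, shape=[None, 2], name="x")
    w = tf1.get_variable("w", shape=[2, 1], initializer=tf1.zeros_initializer())
    b = tf1.get_variable("b", shape=[1], initializer=tf1.zeros_initializer())
    y = tf.add(tf.matmul(x, w), b, name="y")
    return x, y, w


def create_training_graph():
    # Build the inference part first so "w" and "b" keep the same names.
    x, y, w = create_inference_graph()
    target = tf1.placeholder(tf.float32, shape=[None, 1], name="target")
    loss = tf.reduce_mean(tf.square(y - target))
    train_op = tf1.train.AdamOptimizer(learning_rate=0.1).minimize(loss)
    return x, target, train_op, w


ckpt_path = os.path.join(tempfile.mkdtemp(), "model.ckpt")

# 1) Training graph: this Saver covers every variable, optimizer slots included.
train_graph = tf.Graph()
with train_graph.as_default():
    x, target, train_op, w = create_training_graph()
    init_op = tf1.global_variables_initializer()
    train_saver = tf1.train.Saver()

with tf1.Session(graph=train_graph) as sess:
    sess.run(init_op)
    for _ in range(20):
        sess.run(train_op, {x: [[1.0, 2.0]], target: [[3.0]]})
    trained_w = sess.run(w)
    train_saver.save(sess, ckpt_path)

# 2) Inference graph: no optimizer or gradient ops. Saver() with no arguments
#    builds save/restore ops for just this graph's variables ("w" and "b").
infer_graph = tf.Graph()
with infer_graph.as_default():
    x_inf, y_inf, w_inf = create_inference_graph()
    infer_saver = tf1.train.Saver()

with tf1.Session(graph=infer_graph) as sess:
    # No initializer needed: restore() assigns every variable the saver covers.
    infer_saver.restore(sess, ckpt_path)
    restored_w = sess.run(w_inf)
    prediction = sess.run(y_inf, {x_inf: [[1.0, 2.0]]})

print(prediction)
```

The extra optimizer variables in the checkpoint are simply ignored by the inference saver, which is why condition (i) only requires a subset.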
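When the names do not match, for example because the inference graph is built inside an extra scope, the N.B. above calls for an explicit name map. The sketch below assumes the `tf.compat.v1` API; the `"serving"` scope and the values written to the checkpoint are hypothetical. It first writes a small checkpoint with variables named `w` and `b` (standing in for the training checkpoint), then restores it into variables named `serving/w` and `serving/b` by passing a dictionary of checkpoint names to variables as `var_list`.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1  # assumption: TF 1.x-style graph API via tf.compat.v1

ckpt_path = os.path.join(tempfile.mkdtemp(), "model.ckpt")

# Stand-in for the training checkpoint: variables saved as "w" and "b".
src_graph = tf.Graph()
with src_graph.as_default():
    w = tf1.get_variable("w", initializer=tf.constant([[0.5], [-1.0]]))
    b = tf1.get_variable("b", initializer=tf.constant([0.25]))
    init_op = tf1.global_variables_initializer()
    src_saver = tf1.train.Saver()

with tf1.Session(graph=src_graph) as sess:
    sess.run(init_op)
    src_saver.save(sess, ckpt_path)

# Inference graph whose variables live under a hypothetical "serving" scope,
# so their names ("serving/w", "serving/b") no longer match the checkpoint.
infer_graph = tf.Graph()
with infer_graph.as_default():
    with tf1.variable_scope("serving"):
        w_inf = tf1.get_variable("w", shape=[2, 1])
        b_inf = tf1.get_variable("b", shape=[1])
    # Keys are names in the checkpoint; values are variables in this graph.
    name_map = {"w": w_inf, "b": b_inf}
    infer_saver = tf1.train.Saver(var_list=name_map)

with tf1.Session(graph=infer_graph) as sess:
    infer_saver.restore(sess, ckpt_path)
    restored_w, restored_b = sess.run([w_inf, b_inf])

print(w_inf.op.name, restored_w.ravel(), restored_b)
```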