I have trained a model with TensorFlow and used batch normalization during training. Batch normalization requires the user to pass a boolean, called is_training, to indicate whether the model is in the training or testing phase.
When the model was trained, is_training was set as a constant, as shown below:
is_training = tf.constant(True, dtype=tf.bool, name='is_training')
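The question does not show how is_training is consumed. The following hypothetical sketch illustrates why a hard-coded True matters. A batch-norm-style layer typically uses tf.cond to choose batch statistics during training and stored moving statistics during inference, so a constant True predicate always selects the training branch. The names x, moving_mean and bn_out are illustrative, not taken from the real model. The code uses the tf.compat.v1 graph-mode API.

```python
import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1

with tf.Graph().as_default():
    # Hypothetical input; the real model's input name is not shown in the question.
    x = tf1.placeholder(tf.float32, shape=[None, 3], name='x')
    # Hard-coded flag, as in the question.
    is_training = tf.constant(True, dtype=tf.bool, name='is_training')
    # Stand-in for the moving mean that batch normalization accumulates.
    moving_mean = tf1.Variable(tf.zeros([3]), name='moving_mean')

    # Training branch: current batch mean. Inference branch: stored moving mean.
    mean = tf.cond(is_training,
                   lambda: tf.reduce_mean(x, axis=0),
                   lambda: tf.identity(moving_mean))
    bn_out = tf.identity(x - mean, name='bn_out')

    with tf1.Session() as sess:
        sess.run(tf1.global_variables_initializer())
        batch = np.array([[1., 2., 3.], [3., 4., 5.]], dtype=np.float32)
        # Always takes the training branch, because the predicate is a constant True.
        out = sess.run(bn_out, feed_dict={x: batch})
```

Here out is centered on the batch mean [2, 3, 4], giving [[-1, -1, -1], [1, 1, 1]], whatever value moving_mean holds.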
I have saved the trained model; the saved files are the checkpoint file, a .meta file, a .index file, and a .data file. I'd like to restore the model and run inference with it.
The model can't be retrained, so I'd like to restore the existing model, set the value of is_training to False, and then save the model back.
How can I edit the boolean value associated with that node, and save the model again?
You can use the input_map argument of tf.train.import_meta_graph to remap a tensor in the imported graph (here, is_training:0) to an updated value.
import tensorflow as tf

config = tf.ConfigProto(allow_soft_placement=True)
with tf.Session(config=config) as sess:
    # define the replacement is_training tensor (False for inference)
    is_training = tf.constant(False, dtype=tf.bool, name='is_training')
    # import the graph from the checkpoint's .meta file, remapping every
    # use of the original 'is_training:0' tensor to the new constant
    saver = tf.train.import_meta_graph(
        '/path/to/model.meta', input_map={'is_training:0': is_training})
    # restore all weights from the model checkpoint
    saver.restore(sess, '/path/to/model')
    # save the updated graph and variable values
    saver.save(sess, '/path/to/new-model-name')
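To check the remapping end to end, here is a minimal sketch that uses a toy graph instead of the real model. It assumes is_training feeds a tf.cond, and the names x, moving_mean, bn_out, is_training_inference and the temporary paths are hypothetical. The sketch has three steps. It saves a graph with is_training fixed to True, then re-imports that graph with input_map pointing 'is_training:0' at a False constant and saves a copy. Finally, it reloads the copy without any remapping. It is written against tf.compat.v1, so it should also run in TensorFlow 2.x graph mode.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1

ckpt_dir = tempfile.mkdtemp()
old_path = os.path.join(ckpt_dir, 'model')            # hypothetical paths
new_path = os.path.join(ckpt_dir, 'model-inference')
data = np.array([[1., 2., 3.], [3., 4., 5.]], dtype=np.float32)

# 1) Toy "trained" graph with is_training hard-coded to True, saved to disk.
with tf.Graph().as_default():
    x = tf1.placeholder(tf.float32, shape=[None, 3], name='x')
    is_training = tf.constant(True, dtype=tf.bool, name='is_training')
    moving_mean = tf1.Variable([10., 10., 10.], name='moving_mean')
    mean = tf.cond(is_training,
                   lambda: tf.reduce_mean(x, axis=0),
                   lambda: tf.identity(moving_mean))
    tf.identity(x - mean, name='bn_out')
    with tf1.Session() as sess:
        sess.run(tf1.global_variables_initializer())
        tf1.train.Saver().save(sess, old_path)  # writes .meta, .index, .data-*

# 2) Re-import with 'is_training:0' remapped to False, restore weights, save a copy.
with tf.Graph().as_default():
    # A distinct name keeps the new flag apart from the imported 'is_training' node.
    flag = tf.constant(False, dtype=tf.bool, name='is_training_inference')
    saver = tf1.train.import_meta_graph(old_path + '.meta',
                                        input_map={'is_training:0': flag})
    with tf1.Session() as sess:
        saver.restore(sess, old_path)
        remapped_out = sess.run('bn_out:0', feed_dict={'x:0': data})
        saver.save(sess, new_path)

# 3) Reload the saved copy with no input_map; it should take the inference branch.
with tf.Graph().as_default():
    saver = tf1.train.import_meta_graph(new_path + '.meta')
    with tf1.Session() as sess:
        saver.restore(sess, new_path)
        reloaded_out = sess.run('bn_out:0', feed_dict={'x:0': data})
```

input_map rewires the consumers of 'is_training:0' at import time. As a result, the .meta file written by saver.save records the new constant as the predicate. The original is_training node is still present in the saved graph but is no longer used. In both remapped_out and reloaded_out, the input should be centered on the stored moving mean (data - 10) rather than on the batch mean.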