Tags: python, machine-learning, tensorflow, deep-learning, batch-normalization

Restore a trained TensorFlow model, edit the value associated with a node, and save it


I have trained a model with TensorFlow and used batch normalization during training. Batch normalization requires the user to pass a boolean, called is_training, that sets whether the model is in the training or the testing phase.

When the model was trained, is_training was set as a constant, as shown below:

is_training = tf.constant(True, dtype=tf.bool, name='is_training')
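To see why a constant True flag is a problem at inference time, here is a minimal sketch, assuming TensorFlow 1.15 or TensorFlow 2.x used through tf.compat.v1. It is a hand-written batch-normalization step that switches on is_training with tf.cond, not the layer the original model used; moving_mean and moving_var are hypothetical stand-ins for learned statistics. With the flag frozen to True, a single example is normalized with its own batch statistics, so every input collapses to zeros.

```python
import numpy as np
import tensorflow as tf  # assumes TF 1.15, or TF 2.x through tf.compat.v1

v1 = tf.compat.v1

with tf.Graph().as_default():
    x = v1.placeholder(tf.float32, shape=[None, 3], name="x")
    # Frozen to True, exactly like the constant in the question.
    is_training = tf.constant(True, dtype=tf.bool, name="is_training")
    # Hypothetical stand-ins for the moving statistics learned in training.
    moving_mean = tf.constant([0.5, 0.5, 0.5])
    moving_var = tf.constant([2.0, 2.0, 2.0])
    # Per-feature statistics of the current batch.
    batch_mean, batch_var = tf.nn.moments(x, axes=[0])
    # Training: use batch statistics; inference: use the moving statistics.
    mean, var = tf.cond(is_training,
                        lambda: (batch_mean, batch_var),
                        lambda: (moving_mean, moving_var))
    y = tf.nn.batch_normalization(x, mean, var, offset=None, scale=None,
                                  variance_epsilon=1e-3)
    with v1.Session() as sess:
        # Two different single-example "batches".
        out_a = sess.run(y, feed_dict={x: [[1.0, 2.0, 3.0]]})
        out_b = sess.run(y, feed_dict={x: [[7.0, -4.0, 0.0]]})

print(out_a, out_b)  # both are all zeros
```

With is_training set to False, the same step would use moving_mean and moving_var instead, which is why the flag has to change before the model is used for inference.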

I have saved the trained model; the files include the checkpoint file, a .meta file, a .index file, and a .data file. I'd like to restore the model and run inference with it, but the model can't be retrained. So I'd like to restore the existing model, set the value of is_training to False, and then save the model back. How can I edit the boolean value associated with that node and save the model again?
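Saver.restore expects the checkpoint prefix (for example /path/to/model), not one of the individual files. The sketch below, assuming TensorFlow 1.15 or TensorFlow 2.x through tf.compat.v1 and a throwaway temporary directory, saves a tiny hypothetical graph, lists the files Saver writes, and re-imports the .meta file to confirm that the tensor to edit is named is_training:0.

```python
import os
import tempfile

import tensorflow as tf  # assumes TF 1.15, or TF 2.x through tf.compat.v1

v1 = tf.compat.v1
ckpt_dir = tempfile.mkdtemp()
prefix = os.path.join(ckpt_dir, "model")  # checkpoint prefix, not a file name

# Save a tiny hypothetical graph that contains an is_training constant.
with tf.Graph().as_default():
    tf.constant(True, dtype=tf.bool, name="is_training")
    v1.Variable([1.0, 2.0], name="weights")
    with v1.Session() as sess:
        sess.run(v1.global_variables_initializer())
        v1.train.Saver().save(sess, prefix)

# Saver writes a "checkpoint" file plus model.meta, model.index, model.data-*.
files = sorted(os.listdir(ckpt_dir))
print(files)

# The "checkpoint" file records the latest prefix, which restore() takes.
latest = tf.train.latest_checkpoint(ckpt_dir)

# Re-import the .meta file and look up the tensor by its graph name.
with tf.Graph().as_default() as graph:
    v1.train.import_meta_graph(prefix + ".meta")
    flag = graph.get_tensor_by_name("is_training:0")
    op_names = [op.name for op in graph.get_operations()]

print(latest, flag.dtype, "is_training" in op_names)
```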


Solution

  • You can use the input_map argument of tf.train.import_meta_graph to remap the graph's is_training:0 tensor to a new tensor that holds the updated value.

    import tensorflow as tf  # TensorFlow 1.x API

    config = tf.ConfigProto(allow_soft_placement=True)
    with tf.Session(config=config) as sess:
        # define the new is_training tensor, fixed to False for inference
        is_training = tf.constant(False, dtype=tf.bool, name='is_training')

        # import the graph from the checkpoint's .meta file, rewiring every
        # consumer of the old 'is_training:0' tensor to the new constant
        saver = tf.train.import_meta_graph(
            '/path/to/model.meta', input_map={'is_training:0': is_training})

        # restore all weights from the model checkpoint (prefix, not a file)
        saver.restore(sess, '/path/to/model')

        # save the updated graph and the variable values
        saver.save(sess, '/path/to/new-model-name')
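The answer's round trip can be checked end to end with a self-contained sketch, assuming TensorFlow 1.15 or TensorFlow 2.x used through tf.compat.v1. The toy graph (an x placeholder, a moving_mean variable standing in for batch-norm statistics, and a y output) is hypothetical; it uses tf.cond so the flag really selects a branch. The new constant is given a distinct, hypothetical name, is_training_inference, so the example does not depend on how the import resolves a name clash with the original is_training node.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf  # assumes TF 1.15, or TF 2.x through tf.compat.v1

v1 = tf.compat.v1
ckpt_dir = tempfile.mkdtemp()
old_prefix = os.path.join(ckpt_dir, "model")
new_prefix = os.path.join(ckpt_dir, "new-model")
x_value = np.array([[1.0], [3.0]], dtype=np.float32)

# 1) Build and save a toy "trained" graph whose is_training flag is constant.
with tf.Graph().as_default():
    x = v1.placeholder(tf.float32, shape=[None, 1], name="x")
    is_training = tf.constant(True, dtype=tf.bool, name="is_training")
    moving_mean = v1.Variable([10.0], name="moving_mean")
    moving_value = tf.identity(moving_mean)
    batch_mean = tf.reduce_mean(x, axis=0)
    # Subtract the batch mean while training, the stored mean otherwise.
    mean = tf.cond(is_training, lambda: batch_mean, lambda: moving_value)
    tf.identity(x - mean, name="y")
    with v1.Session() as sess:
        sess.run(v1.global_variables_initializer())
        v1.train.Saver().save(sess, old_prefix)

# 2) Same steps as the answer: remap is_training:0, restore, save again.
with tf.Graph().as_default():
    with v1.Session() as sess:
        new_flag = tf.constant(False, dtype=tf.bool,
                               name="is_training_inference")
        saver = v1.train.import_meta_graph(
            old_prefix + ".meta", input_map={"is_training:0": new_flag})
        saver.restore(sess, old_prefix)
        saver.save(sess, new_prefix)

# 3) Reload the re-saved model with no input_map and run inference.
with tf.Graph().as_default():
    with v1.Session() as sess:
        saver = v1.train.import_meta_graph(new_prefix + ".meta")
        saver.restore(sess, new_prefix)
        y_value = sess.run("y:0", feed_dict={"x:0": x_value})

print(y_value.ravel())  # x minus the stored mean 10.0: [-9. -7.]
```

Because input_map rewires every consumer of is_training:0, nothing in the re-saved graph reads the original constant anymore, so the reloaded model takes the inference branch without any extra feeding.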