tensorflow, batch-normalization

Tensorflow save/restore batch norm


I trained a model with batch norm in Tensorflow. I would like to save the model and restore it for further use. The batch norm is done by

def batch_norm(input, phase):
    return tf.layers.batch_normalization(input, training=phase)

where phase is True during training and False during testing.
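
For context, a minimal sketch of how such a layer might be wired into the graph (the placeholder name and layer sizes below are assumptions for illustration):

import tensorflow as tf

# Illustrative graph only; names and sizes are made up
x = tf.placeholder(tf.float32, [None, 128], name="x")
phase = tf.placeholder(tf.bool, name="phase")   # True while training, False at test time

hidden = tf.layers.dense(x, 64, activation=tf.nn.relu)
normed = batch_norm(hidden, phase)              # the helper defined above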

It seems like simply calling

saver = tf.train.Saver()
saver.save(sess, savedir + "ckpt")

does not work well: when I restore the model (roughly as sketched below), it first reports that the restore succeeded, but then raises Attempting to use uninitialized value batch_normalization_585/beta as soon as I run a node in the graph. Is this related to not saving the model properly, or to something else I've missed?
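
For reference, the restore step assumed here looks roughly like this (savedir is the same directory used when saving):

# Sketch of the restore side; savedir is carried over from the saving code above
saver = tf.train.Saver()
with tf.Session() as sess:
    saver.restore(sess, savedir + "ckpt")
    # running any op that reads the batch-norm variables then fails with
    # "Attempting to use uninitialized value batch_normalization_585/beta"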


Solution

  • I also had the "Attempting to use uninitialized value batch_normalization_585/beta" error. It comes from declaring the saver with empty parentheses, like this:

             saver = tf.train.Saver() 
    

    The saver will then save only the variables contained in tf.trainable_variables(), which does not include the moving averages of batch normalization. To include these variables in the saved checkpoint, you need to do:

             saver = tf.train.Saver(tf.global_variables())
    

    which saves ALL the variables, so it is more memory-consuming. Alternatively, you can identify the variables that hold the moving averages/variances and save them together with the trainable ones (a sketch for collecting them follows below):

             saver = tf.train.Saver(tf.trainable_variables() + list_of_extra_variables)
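
    A common way to build that extra list is to filter the global variables by the default moving-statistics names (a sketch, assuming the batch norm layers keep TensorFlow's default moving_mean / moving_variance naming):

             # Sketch: collect the batch-norm moving statistics by their default names
             bn_vars = [v for v in tf.global_variables()
                        if "moving_mean" in v.name or "moving_variance" in v.name]

             saver = tf.train.Saver(tf.trainable_variables() + bn_vars)
             saver.save(sess, savedir + "ckpt")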