Tags: tensorflow, batch-normalization

significance of "trainable" and "training" flag in tf.layers.batch_normalization


What is the significance of the `trainable` and `training` flags in `tf.layers.batch_normalization`? How do these two differ during training and prediction?


Solution

  • The batch norm has two phases:

    1. Training (`training` should be `True`):
       -  Normalize layer activations using the mean and variance of the 
          current batch, then scale and shift them with `gamma` and `beta`.
       -  Update the `moving_avg` and `moving_var` statistics; the update ops 
          live in `tf.GraphKeys.UPDATE_OPS` and must be run explicitly. 
          (In the experiments below they only take effect when `trainable` 
          is also `True`.)
    2. Inference (`training` should be `False`):
       -  Normalize layer activations using the stored `moving_avg` and 
          `moving_var`, then scale and shift with `gamma` and `beta`.

    Separately, `trainable` controls whether `beta` and `gamma` are added to 
    the trainable variables, i.e. whether the optimizer can update them.
    
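The two phases can be sketched in plain numpy (a simplified illustration of the computation, not the library implementation; `eps` stands in for the layer's `epsilon` parameter, which defaults to 1e-3 in `tf.layers.batch_normalization`):

```python
import numpy as np

def batch_norm(x, gamma, beta, moving_mean, moving_var, training, eps=1e-3):
    """Sketch of batch norm's two phases for an NHWC tensor (per-channel stats)."""
    if training:
        # training: use the statistics of the current batch
        mean = x.mean(axis=(0, 1, 2))
        var = x.var(axis=(0, 1, 2))
    else:
        # inference: use the stored moving statistics
        mean, var = moving_mean, moving_var
    return gamma * (x - mean) / np.sqrt(var + eps) + beta
```

In a real training loop the moving statistics would also be updated each step; that update is shown separately below the experiments.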

    Example code to illustrate a few cases:

    import numpy as np
    import tensorflow as tf

    # random image
    img = np.random.randint(0, 10, (2, 2, 4)).astype(np.float32)
    
    # batch norm params initialized
    beta = np.ones((4)).astype(np.float32)*1 # all ones 
    gamma = np.ones((4)).astype(np.float32)*2 # all twos
    moving_mean = np.zeros((4)).astype(np.float32) # all zeros
    moving_var = np.ones((4)).astype(np.float32) # all ones
    
    # placeholder for the input image
    _input = tf.placeholder(tf.float32, shape=(1,2,2,4), name='input')
    
    # batch norm
    out = tf.layers.batch_normalization(
           _input,
           beta_initializer=tf.constant_initializer(beta),
           gamma_initializer=tf.constant_initializer(gamma),
           moving_mean_initializer=tf.constant_initializer(moving_mean),
           moving_variance_initializer=tf.constant_initializer(moving_var),
           training=False, trainable=False)
    
    
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    init_op = tf.global_variables_initializer()
    
    ## Run the graph in a session
    
    with tf.Session() as sess:
    
        # init the variables   
        sess.run(init_op)
    
        for i in range(2):
            ops, o = sess.run([update_ops, out], feed_dict={_input: np.expand_dims(img, 0)})
            print('beta', sess.run('batch_normalization/beta:0'))
            print('gamma', sess.run('batch_normalization/gamma:0'))
            print('moving_avg',sess.run('batch_normalization/moving_mean:0'))
            print('moving_variance',sess.run('batch_normalization/moving_variance:0'))
            print('out', np.round(o))
            print('')
    

    When training=False and trainable=False:

      img = [[[4., 5., 9., 0.]...
      out = [[ 9. 11. 19.  1.]... 
      The activation is normalized with `moving_avg` (all zeros) and 
      `moving_var` (all ones), which here is the identity, so it is simply 
      scaled and shifted using `gamma` and `beta`: out ≈ 2*img + 1.
    
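This case can be checked by hand: with `moving_mean = 0` and `moving_var = 1` the normalization is (up to the small `epsilon`) the identity, so the layer reduces to `gamma * x + beta`:

```python
import numpy as np

row = np.array([4., 5., 9., 0.], dtype=np.float32)  # first row of img above
gamma, beta = 2.0, 1.0
# moving_mean=0, moving_var=1 -> normalization is the identity
out = gamma * row + beta
print(out)  # [ 9. 11. 19.  1.] -- matches the output above
```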

    When training=True and trainable=False:

      out = [[ 2.  2.  3. -1.] ...
      The activation is normalized using the mean and variance of the current 
      batch, then scaled and shifted with `gamma` and `beta`. 
      The moving averages are not updated.
    
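Why the values shrink: under `training=True` each channel is first normalized to roughly zero mean and unit variance, so after scaling and shifting the outputs cluster around `beta`. A numpy check of this (an illustration of the per-channel statistics, not the exact TF computation):

```python
import numpy as np

x = np.random.randint(0, 10, (1, 2, 2, 4)).astype(np.float32)
mean = x.mean(axis=(0, 1, 2))         # per-channel batch mean
var = x.var(axis=(0, 1, 2))           # per-channel batch variance
z = (x - mean) / np.sqrt(var + 1e-3)  # normalized: ~zero mean, ~unit variance
out = 2.0 * z + 1.0                   # scale by gamma=2, shift by beta=1
print(out.mean(axis=(0, 1, 2)))       # ~[1. 1. 1. 1.], i.e. beta
```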

    When training=True and trainable=True:

      The output is the same as above, but `moving_avg` and `moving_var` get 
      updated to new values:
    
      moving_avg [0.03249997 0.03499997 0.06499994 0.02749997]
      moving_variance [1.0791667 1.1266665 1.0999999 1.0925]
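The updated values follow the layer's exponential moving average rule, `moving = moving * momentum + batch_stat * (1 - momentum)`, where `momentum` defaults to 0.99 in `tf.layers.batch_normalization`. A numpy sketch of two update steps, using an assumed batch mean and variance for illustration:

```python
import numpy as np

momentum = 0.99              # default in tf.layers.batch_normalization
moving_mean = np.zeros(4)
moving_var = np.ones(4)
batch_mean = np.full(4, 4.5) # assumed batch statistics, for illustration only
batch_var = np.full(4, 8.0)

for _ in range(2):           # two sess.run calls, as in the loop above
    moving_mean = moving_mean * momentum + batch_mean * (1 - momentum)
    moving_var = moving_var * momentum + batch_var * (1 - momentum)

print(moving_mean)           # creeps slowly from 0 toward batch_mean
print(moving_var)            # creeps slowly from 1 toward batch_var
```

With the default momentum this high, the moving statistics change only slightly per step, which is why the printed values above are still close to their initial 0s and 1s after two iterations.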