
Batch Normalization causes huge difference between training and inference loss


I followed the instructions on TensorFlow's documentation page for tf.layers.batch_normalization and set the training argument to True during training and to False during inference (validation and test).
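For context, the training flag switches batch normalization between two sources of statistics. The pure-Python sketch below illustrates the standard formula for a single channel of scalar activations (the function name batch_norm is hypothetical and this is not TensorFlow's implementation). With training=True it normalizes using the current batch's mean and variance; with training=False it uses the stored moving mean and variance. The eps=1e-3 default mirrors the epsilon default of tf.layers.batch_normalization.

```python
import math


def batch_norm(xs, training, moving_mean, moving_var,
               gamma=1.0, beta=0.0, eps=1e-3):
    """Batch-normalize one channel of scalar activations.

    training=True  -> normalize with this batch's mean and variance.
    training=False -> normalize with the stored moving statistics.
    """
    if training:
        mean = sum(xs) / len(xs)
        var = sum((x - mean) ** 2 for x in xs) / len(xs)  # biased variance
    else:
        mean, var = moving_mean, moving_var
    return [gamma * (x - mean) / math.sqrt(var + eps) + beta for x in xs]


batch = [10.0, 12.0, 14.0]
# Moving statistics equal to this batch's own statistics (mean 12, variance 8/3).
train_out = batch_norm(batch, training=True, moving_mean=12.0, moving_var=8 / 3)
infer_out = batch_norm(batch, training=False, moving_mean=12.0, moving_var=8 / 3)
print(train_out)
print(infer_out)
```

The two modes only agree when the moving statistics match what the batches actually look like; if they do not, training and inference outputs diverge even on the same data.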

However, with batch normalization I always get a huge difference between the training and validation loss, for example:

2018-09-11 09:22:34: step 993, loss 1.23001, acc 0.488638
2018-09-11 09:22:35: step 994, loss 0.969551, acc 0.567364
2018-09-11 09:22:35: step 995, loss 1.31113, acc 0.5291
2018-09-11 09:22:35: step 996, loss 1.03135, acc 0.607861
2018-09-11 09:22:35: step 997, loss 1.16031, acc 0.549255
2018-09-11 09:22:36: step 998, loss 1.42303, acc 0.454694
2018-09-11 09:22:36: step 999, loss 1.33105, acc 0.496234
2018-09-11 09:22:36: step 1000, loss 1.14326, acc 0.527387
Round 4: valid
Loading from valid, 1383 samples available
2018-09-11 09:22:36: step 1000, loss 44.3765, acc 0.000743037
2018-09-11 09:22:36: step 1000, loss 36.9143, acc 0.0100708
2018-09-11 09:22:37: step 1000, loss 35.2007, acc 0.0304909
2018-09-11 09:22:37: step 1000, loss 39.9036, acc 0.00510307
2018-09-11 09:22:37: step 1000, loss 42.2612, acc 0.000225067
2018-09-11 09:22:37: step 1000, loss 29.9964, acc 0.0230831
2018-09-11 09:22:37: step 1000, loss 28.1444, acc 0.00278473

And sometimes it is even worse (for another model):

2018-09-11 09:19:39: step 591, loss 0.967038, acc 0.630745
2018-09-11 09:19:40: step 592, loss 1.26836, acc 0.406095
2018-09-11 09:19:40: step 593, loss 1.33029, acc 0.536824
2018-09-11 09:19:41: step 594, loss 0.809579, acc 0.651354
2018-09-11 09:19:41: step 595, loss 1.41018, acc 0.491683
2018-09-11 09:19:42: step 596, loss 1.37515, acc 0.462998
2018-09-11 09:19:42: step 597, loss 0.972473, acc 0.663277
2018-09-11 09:19:43: step 598, loss 1.01062, acc 0.624355
2018-09-11 09:19:43: step 599, loss 1.13029, acc 0.53893
2018-09-11 09:19:44: step 600, loss 1.41601, acc 0.502889
Round 2: valid
Loading from valid, 1383 samples available
2018-09-11 09:19:44: step 600, loss 23242.2, acc 0.204348
2018-09-11 09:19:44: step 600, loss 22038, acc 0.196325
2018-09-11 09:19:44: step 600, loss 22223, acc 0.0991791
2018-09-11 09:19:44: step 600, loss 22039.2, acc 0.220871
2018-09-11 09:19:45: step 600, loss 25587.3, acc 0.155427
2018-09-11 09:19:45: step 600, loss 12617.7, acc 0.481486
2018-09-11 09:19:45: step 600, loss 17226.6, acc 0.234989
2018-09-11 09:19:45: step 600, loss 18530.3, acc 0.321573
2018-09-11 09:19:45: step 600, loss 21043.5, acc 0.157935
2018-09-11 09:19:46: step 600, loss 17232.6, acc 0.412151
2018-09-11 09:19:46: step 600, loss 28958.8, acc 0.297459
2018-09-11 09:19:46: step 600, loss 22603.7, acc 0.146518
2018-09-11 09:19:46: step 600, loss 29485.6, acc 0.266186
2018-09-11 09:19:46: step 600, loss 26039.7, acc 0.215589

The batch normalization code I use:

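import tensorflow as tf  # TensorFlow 1.x (uses tf.layers and tf.contrib)
from tensorflow.contrib.layers import xavier_initializer

relu = tf.nn.relu  # assumption: the relu helper used below is tf.nn.relu

def bn(inp, train_flag, name=None):
    # Uses batch statistics when train_flag is True and the moving statistics
    # when it is False; the moving-average updates go into tf.GraphKeys.UPDATE_OPS
    return tf.layers.batch_normalization(inp, training=train_flag, name=name)

def gn(inp, groups=32):
    return tf.contrib.layers.group_norm(inp, groups=groups)

def conv(*args, padding='same', with_relu=True, with_bn=False,
         train_flag=None, with_gn=False, name=None, **kwargs):
    # args: inp, filters, kernel_size, strides
    use_bias = not with_bn  # batch norm's learned offset (beta) replaces the bias
    x = tf.layers.conv2d(*args, **kwargs, padding=padding,
                         kernel_initializer=xavier_initializer(),
                         use_bias=use_bias, name=name)
    if with_bn:
        bn_name = name + '/batchnorm' if name is not None else None
        x = bn(x, train_flag, name=bn_name)
    if with_gn:
        x = gn(x)
    if with_relu:
        x = relu(x)
    return x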

When I remove the batch normalization layers, the huge difference between training and validation loss disappears.

The following code is used for the optimization step:

update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    ...  # the training op is created here, so it depends on update_ops
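The UPDATE_OPS dependency matters because the moving statistics used at inference are changed only by those update ops. If they never run, moving_mean and moving_variance keep their initial values (0 and 1 in tf.layers.batch_normalization), while training-mode normalization, which uses batch statistics, still looks fine. The pure-Python simulation below is an illustration under assumptions (scalar activations drawn from a Gaussian with mean 50 and standard deviation 10, hypothetical helper names, no TensorFlow): it trains the moving averages with and without running the update, then normalizes a validation batch in inference mode.

```python
import math
import random


def batch_stats(xs):
    """Mean and (biased) variance of one batch, as used when training=True."""
    mean = sum(xs) / len(xs)
    var = sum((x - mean) ** 2 for x in xs) / len(xs)
    return mean, var


def train_moving_stats(batches, momentum=0.99, run_update_ops=True):
    """Simulate training and return the final (moving_mean, moving_var)."""
    moving_mean, moving_var = 0.0, 1.0  # initial values of the moving statistics
    for xs in batches:
        mean, var = batch_stats(xs)  # training-mode normalization uses these
        if run_update_ops:  # stands in for running the ops in UPDATE_OPS
            moving_mean = momentum * moving_mean + (1 - momentum) * mean
            moving_var = momentum * moving_var + (1 - momentum) * var
    return moving_mean, moving_var


def normalize(xs, mean, var, eps=1e-3):
    """Inference-mode normalization with the given (moving) statistics."""
    return [(x - mean) / math.sqrt(var + eps) for x in xs]


random.seed(0)
# Hypothetical activations far from the initial mean 0 / variance 1.
batches = [[random.gauss(50.0, 10.0) for _ in range(32)] for _ in range(1000)]
val_batch = [random.gauss(50.0, 10.0) for _ in range(32)]

for run in (True, False):
    m, v = train_moving_stats(batches, run_update_ops=run)
    out = normalize(val_batch, m, v)
    mean_abs = sum(abs(o) for o in out) / len(out)
    print(f"update ops run={run}: moving_mean={m:.2f}, "
          f"moving_var={v:.2f}, mean |output|={mean_abs:.2f}")
```

Without the updates, inference-mode outputs are off by roughly the scale of the raw activations, which is consistent with the very large validation losses in the logs above.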

The model is trained from scratch without transfer learning.

Following the issue "Batch Normalization layer gives significant difference between train and validation loss on the exact same data", I tried reducing the momentum, but that did not work either.
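Reducing the momentum only changes how quickly the moving averages track the batch statistics; it cannot help if the update ops never run. Assuming the usual exponential-moving-average rule moving = momentum * moving + (1 - momentum) * batch_value, the weight still left on the initial value after n updates is momentum**n. The small sketch below (the helper name steps_to_forget is hypothetical) computes how many updates it takes for that weight to drop below 1%.

```python
import math


def steps_to_forget(momentum, remaining=0.01):
    """Smallest n with momentum**n < remaining, i.e. the number of updates
    after which the initial moving value keeps less than `remaining` weight
    under moving = momentum * moving + (1 - momentum) * batch_value."""
    return math.floor(math.log(remaining) / math.log(momentum)) + 1


for momentum in (0.99, 0.9, 0.5):
    print(f"momentum={momentum}: {steps_to_forget(momentum)} updates")
```

With the default momentum=0.99 this gives 459 updates. Since the logs above are at steps 600 and 1000, the initial moving values should carry little weight by then if the update ops were actually running, so lowering the momentum alone would not be expected to close the gap.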

I am wondering why this happens. Any advice would be much appreciated.

Added: train_flag is a placeholder used throughout the whole model.


Solution

  • In my case, I was wrongly calling update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) only once.

    With multiple GPUs, tf.get_collection(tf.GraphKeys.UPDATE_OPS) has to be called for each GPU after that GPU's subnetwork has been defined and before compute_gradients. In addition, after the towers of all subnetworks have been combined, it has to be called again before apply_gradients.

    Alternatively, call update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) once, after the whole network (including all subnetworks) has been defined, so that it returns the update ops of every tower. This requires two for loops: one for defining the subnetworks and one for computing the gradients.

    An example is shown as follows:

    # Multiple GPUs
    tmp, l = [], 0
    for i in range(opt.gpu_num):
        r = min(l + opt.batch_split, opt.batchsize)
        with tf.device('/gpu:%d' % i), \
             tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
    
            print("Setting up networks on GPU", i)
            inp_ = tf.identity(inps[l:r])
            label_ = tf.identity(labels[l:r])
            for j, val in enumerate(setup_network(inp_, label_)): # loss, pred, accuracy
                if i == 0: tmp += [[]] # [[], [], []]
                tmp[j] += [val]
        l = r
    
    tmp += [[]]  # tmp[-1] collects the (gradient, variable) lists of each GPU
    # Calculate update_ops only after the networks on all GPUs have been defined,
    # so the collection contains the batch-norm updates of every tower
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    for i in range(opt.gpu_num):
        with tf.device('/gpu:%d' % i), \
             tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):

            print("Setting up gradients on GPU", i)
            tmp[-1] += [setup_grad(optim, tmp[0][i])]  # tmp[0][i]: loss of GPU i
    

    Added:

    I have also added the setup_grad function:

    def setup_grad(optim, loss):
        # The gradient ops created here will only run after update_ops have executed
        with tf.control_dependencies(update_ops):
            update_vars = None
            if opt.to_train is not None:
                # Flat list of the trainable variables under each scope in opt.to_train
                update_vars = [v for s in opt.to_train
                               for v in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=s)]
            total_loss = loss[0] + opt.seg_weight * loss[1]
            return optim.compute_gradients(total_loss, var_list=update_vars)
    

    and, for reference, the later apply_gradients step:

    # apply_gradients will only run after update_ops have executed;
    # grads holds the (gradient, variable) pairs combined from all towers
    with tf.control_dependencies(update_ops):
        if opt.clip_grad:
            grads = [(tf.clip_by_value(g, -opt.clip_grad, opt.clip_grad), v)
                     if g is not None else (g, v)
                     for g, v in grads]
        train_op = optim.apply_gradients(grads, global_step=global_step)
    

    If the batch size on each GPU is small, batch normalization might not improve performance, because TensorFlow currently does not support synchronized batch normalization across GPUs: each GPU computes the batch statistics only from its own slice of the batch.
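The call order matters because tf.get_collection returns a new list holding the collection's contents at the moment it is called, not a live view, while each batch normalization layer adds its update ops to UPDATE_OPS as the graph is being built. The pure-Python sketch below mimics that behavior with a hypothetical Graph class and build_tower helper (toy stand-ins, not TensorFlow's classes) to show why reading the collection before all towers exist misses update ops.

```python
class Graph:
    """Toy stand-in for a TF1 graph's collections (illustration only)."""

    def __init__(self):
        self._collections = {}

    def add_to_collection(self, name, value):
        self._collections.setdefault(name, []).append(value)

    def get_collection(self, name):
        # Like tf.get_collection: a *new* list holding what is there right now.
        return list(self._collections.get(name, []))


UPDATE_OPS = "update_ops"


def build_tower(graph, gpu, num_bn_layers=3):
    """Pretend to build one tower; each BN layer registers one update op."""
    for layer in range(num_bn_layers):
        graph.add_to_collection(UPDATE_OPS, f"gpu{gpu}/bn{layer}/update")


graph = Graph()
too_early = graph.get_collection(UPDATE_OPS)  # read before any tower exists
after_first = None
for gpu in range(2):
    build_tower(graph, gpu)
    if gpu == 0:
        after_first = graph.get_collection(UPDATE_OPS)  # only tower 0 so far
complete = graph.get_collection(UPDATE_OPS)  # read after all towers exist

print(len(too_early), len(after_first), len(complete))  # prints: 0 3 6
```

Reading the collection once after every tower has been built, as in the second approach above, returns all of the update ops.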