
Batch normalization: fixed samples or different samples by dimension?


Some questions came to me as I read the paper 'Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift'.

The paper says:

Since m examples from the training data can estimate the mean and variance of the whole training set, we use mini-batches to train the batch normalization parameters.
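For reference, the paper's batch normalizing transform (its Algorithm 1) can be written out per dimension. This sketch uses the question's notation, with $x^{(i)}_k$ the $k$-th component of example $x^{(i)}$ in a mini-batch $\mathcal{B} = \{x^{(1)}, \dots, x^{(m)}\}$; the explicit index $k$ is added here for clarity, and $\epsilon$ is a small constant for numerical stability:

```latex
\begin{aligned}
\mu_{\mathcal{B},k} &= \frac{1}{m}\sum_{i=1}^{m} x^{(i)}_k, &
\sigma^2_{\mathcal{B},k} &= \frac{1}{m}\sum_{i=1}^{m}\bigl(x^{(i)}_k - \mu_{\mathcal{B},k}\bigr)^2,\\
\hat{x}^{(i)}_k &= \frac{x^{(i)}_k - \mu_{\mathcal{B},k}}{\sqrt{\sigma^2_{\mathcal{B},k} + \epsilon}}, &
y^{(i)}_k &= \gamma_k\,\hat{x}^{(i)}_k + \beta_k.
\end{aligned}
```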

My question is:

Are they choosing m examples and then fitting the batch norm parameters concurrently, or choosing a different set of m examples for each input dimension?

E.g., the training set is composed of n-dimensional examples $x^{(i)} = (x_1, x_2, \dots, x_n)$. For a fixed batch $M = \{x^{(1)}, x^{(2)}, \dots, x^{(m)}\}$, fit all of $\gamma_1, \dots, \gamma_n$ and $\beta_1, \dots, \beta_n$ together.

vs

For each pair $\gamma_i, \beta_i$, pick a different batch $M_i = \{x^{(1)}_i, \dots, x^{(m)}_i\}$.
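To make the two readings concrete, here is a small NumPy sketch on synthetic data; the array `X`, the index arrays and the sizes are hypothetical and only for illustration. Both readings produce one mean and one variance per dimension; they differ only in which examples feed each dimension's statistics:

```python
import numpy as np

rng = np.random.default_rng(0)
N, n, m = 100, 3, 8            # training-set size, dimensions, batch size (illustrative)
X = rng.normal(size=(N, n))    # hypothetical training set, one row per example x^(i)

# Reading 1: one mini-batch of m examples shared by every dimension.
idx = rng.choice(N, size=m, replace=False)
mean_shared = X[idx].mean(axis=0)   # shape (n,): one mean per dimension
var_shared = X[idx].var(axis=0)     # biased (1/m) variance, as in the paper

# Reading 2: an independent mini-batch M_i for each dimension i.
mean_separate = np.empty(n)
var_separate = np.empty(n)
for i in range(n):
    idx_i = rng.choice(N, size=m, replace=False)
    mean_separate[i] = X[idx_i, i].mean()
    var_separate[i] = X[idx_i, i].var()

print(mean_shared.shape, mean_separate.shape)  # (3,) (3,)
```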


Solution

  • I haven't found this question on Cross Validated or Data Science, so I can only answer it here. Feel free to migrate it if necessary.

    The mean and variance for all dimensions are computed at once from the same mini-batch, and moving averages of these statistics are maintained across batches. Here's what it looks like in TF code:

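    import tensorflow as tf
    from tensorflow.python.training import moving_averages

    def update_mean_var(incoming, axis, moving_mean, moving_variance, decay):
        mean, variance = tf.nn.moments(incoming, axis)  # batch statistics, one per dimension
        update_moving_mean = moving_averages.assign_moving_average(moving_mean, mean, decay, zero_debias=False)
        update_moving_variance = moving_averages.assign_moving_average(moving_variance, variance, decay, zero_debias=False)
        # Return the batch statistics only after both moving averages have been updated
        with tf.control_dependencies([update_moving_mean, update_moving_variance]):
            return tf.identity(mean), tf.identity(variance)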
    

    You don't need to worry about the technical details; here's what's going on:

    • First, the mean and variance of the whole incoming batch are computed along the batch axis. Both are vectors (more precisely, tensors) with one entry per dimension.
    • Then the current values of moving_mean and moving_variance are updated by an assign_moving_average call, which essentially computes variable * decay + value * (1 - decay).
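The two steps above can be sketched without TensorFlow, assuming plain NumPy arrays in place of TF variables. `batch_moments` and this `assign_moving_average` are illustrative stand-ins, not the TF functions, and TF's optional zero-debiasing is ignored:

```python
import numpy as np

def batch_moments(x, axis=0):
    """Per-dimension mean and (biased) variance over the batch axis."""
    return x.mean(axis=axis), x.var(axis=axis)

def assign_moving_average(variable, value, decay):
    """Plain exponential moving average: variable * decay + value * (1 - decay)."""
    return variable * decay + value * (1 - decay)

# Hypothetical setup: 4 dimensions, running stats initialised to 0 and 1.
moving_mean = np.zeros(4)
moving_variance = np.ones(4)
decay = 0.9

rng = np.random.default_rng(0)
for _ in range(3):
    batch = rng.normal(loc=5.0, scale=2.0, size=(32, 4))  # m = 32 examples per batch
    mean, variance = batch_moments(batch)                 # both have shape (4,)
    moving_mean = assign_moving_average(moving_mean, mean, decay)
    moving_variance = assign_moving_average(moving_variance, variance, decay)
```

Every dimension's statistics come from the same 32 rows of each batch; only the running estimates carry information across batches.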

    Every time batch norm is executed, it sees only the current batch plus the moving averages of statistics from previous batches.
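Putting the pieces together, here is a NumPy sketch of one batch-norm layer on a (batch, n) array. The function name `batch_norm_forward` is hypothetical, and the `decay` and `eps` defaults are illustrative rather than taken from TF or the paper. In training mode it normalizes with the current batch's statistics and updates the running ones; in inference mode it uses only the running statistics. Each dimension has its own $\gamma_k$, $\beta_k$, but all dimensions are normalized with statistics from the same batch:

```python
import numpy as np

def batch_norm_forward(x, gamma, beta, moving_mean, moving_variance,
                       training, decay=0.99, eps=1e-3):
    """Sketch of a batch-norm layer; returns output and (possibly updated) running stats."""
    if training:
        # Statistics of the current batch only, one value per dimension.
        mean, variance = x.mean(axis=0), x.var(axis=0)
        # Fold the current batch into the running estimates.
        moving_mean = moving_mean * decay + mean * (1 - decay)
        moving_variance = moving_variance * decay + variance * (1 - decay)
    else:
        # At inference time, use the statistics accumulated over previous batches.
        mean, variance = moving_mean, moving_variance
    x_hat = (x - mean) / np.sqrt(variance + eps)
    # gamma and beta have shape (n,): one scale/shift pair per dimension.
    return gamma * x_hat + beta, moving_mean, moving_variance

rng = np.random.default_rng(1)
n = 3
gamma, beta = np.ones(n), np.zeros(n)
moving_mean, moving_variance = np.zeros(n), np.ones(n)

batch = rng.normal(loc=2.0, scale=3.0, size=(64, n))
y, moving_mean, moving_variance = batch_norm_forward(
    batch, gamma, beta, moving_mean, moving_variance, training=True)
# With gamma = 1 and beta = 0, each output column is roughly zero-mean, unit-variance.
```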