
How do you set the axes parameter in TensorFlow moments for batch normalization?


I am planning to implement a batch normalization function similar to the one in this blog post (or to just use tf.nn.batch_normalization), using tf.nn.moments to compute the mean and variance, but I want to do it for temporal data of both vector and image type. I am having trouble understanding how to set the axes argument of tf.nn.moments correctly.

My input data for vector sequences has shape (batch, timesteps, channels), and my input data for image sequences has shape (batch, timesteps, height, width, 3) (note they are RGB images). In both cases I want the normalization to happen across the entire batch and across timesteps, meaning I am not trying to maintain separate mean/variance for different timesteps.

How do I correctly set axes for the different data types (e.g. images vs. vectors) and for temporal vs. non-temporal data?


Solution

  • The simplest way to think of it is: all dimensions listed in axes will be collapsed, and statistics will be computed over slices taken along those axes. Example:

    import tensorflow as tf
    
    x = tf.random.uniform((8, 10, 4))
    
    print(x, '\n')
    print(tf.nn.moments(x, axes=[0]), '\n')
    print(tf.nn.moments(x, axes=[0, 1]))
    
    which in TF 1.x graph mode prints the symbolic tensors (the shapes are what matter):
    
    Tensor("random_uniform:0", shape=(8, 10, 4), dtype=float32)
    
    (<tf.Tensor 'moments/Squeeze:0'   shape=(10, 4) dtype=float32>,
     <tf.Tensor 'moments/Squeeze_1:0' shape=(10, 4) dtype=float32>)
    
    (<tf.Tensor 'moments_1/Squeeze:0'   shape=(4,) dtype=float32>,
     <tf.Tensor 'moments_1/Squeeze_1:0' shape=(4,) dtype=float32>)
    
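    In TF 2.x eager execution the same calls return concrete tensors rather than graph symbols, so the shapes can be checked directly; a quick sketch, assuming TF 2.x:

    import tensorflow as tf
    
    x = tf.random.uniform((8, 10, 4))
    
    mean, var = tf.nn.moments(x, axes=[0])
    print(mean.shape, var.shape)     # (10, 4) (10, 4)
    
    mean, var = tf.nn.moments(x, axes=[0, 1])
    print(mean.shape, var.shape)     # (4,) (4,)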

    From the source, math_ops.reduce_mean is used to compute both the mean and the variance; in pseudocode, it operates as follows:

    # axes = [0]
    mean = (x[0, :, :] + x[1, :, :] + ... + x[7, :, :]) / 8
    mean.shape == (10, 4)  # each slice's shape is (10, 4), so the sum's shape is also (10, 4)
    
    # axes = [0, 1]
    mean = (x[0, 0, :] + x[1, 0, :] + ... + x[7, 0, :] +
            x[0, 1, :] + x[1, 1, :] + ... + x[7, 1, :] +
            ... +
            x[0, 9, :] + x[1, 9, :] + ... + x[7, 9, :]) / (8 * 10)
    mean.shape == (4,)  # each slice's shape is (4,), so the sum's shape is also (4,)
    
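    Since both statistics come down to reduce_mean calls, you can reproduce tf.nn.moments by hand; a minimal sketch, assuming TF 2.x:

    import tensorflow as tf
    
    x = tf.random.uniform((8, 10, 4))
    axes = [0, 1]
    
    mean, var = tf.nn.moments(x, axes=axes)
    
    # recompute both statistics with reduce_mean, as the source does
    manual_mean = tf.reduce_mean(x, axis=axes)                           # shape (4,)
    manual_var  = tf.reduce_mean(tf.square(x - manual_mean), axis=axes)  # shape (4,)
    
    print(tf.reduce_max(tf.abs(mean - manual_mean)).numpy())  # ~0
    print(tf.reduce_max(tf.abs(var  - manual_var)).numpy())   # ~0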

    In other words, axes=[0] computes (timesteps, channels) statistics with respect to samples: it iterates over samples and computes the mean & variance of the (timesteps, channels) slices. Thus, for

    "normalization to happen across the entire batch and across timesteps, meaning I am not trying to maintain separate mean/variance for different timesteps"

    you just need to collapse the timesteps dimension as well (along with samples), computing statistics by iterating over both samples and timesteps:

    axes = [0, 1]
    
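    Putting it together for the vector-sequence case, a minimal normalization sketch using tf.nn.batch_normalization (assuming TF 2.x; gamma, beta, and the epsilon value are placeholder choices, and in a real layer gamma/beta would be learned):

    import tensorflow as tf
    
    x = tf.random.uniform((8, 10, 4))           # (batch, timesteps, channels)
    
    # per-channel statistics over samples and timesteps, each of shape (4,)
    mean, var = tf.nn.moments(x, axes=[0, 1])
    
    # hypothetical learnable scale/offset, one per channel
    gamma = tf.Variable(tf.ones(4))
    beta  = tf.Variable(tf.zeros(4))
    
    y = tf.nn.batch_normalization(x, mean, var, offset=beta, scale=gamma,
                                  variance_epsilon=1e-3)
    print(y.shape)                              # (8, 10, 4)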

    Same story for images, except now height and width are additional non-channel dimensions: for your (batch, timesteps, height, width, 3) sequences you'd use axes = [0, 1, 2, 3] (to collapse samples, timesteps, height, and width), and for non-temporal (batch, height, width, channels) images, axes = [0, 1, 2].

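    For the image-sequence shape from the question, a quick check (assuming TF 2.x; the 32x32 spatial size is arbitrary):

    import tensorflow as tf
    
    frames = tf.random.uniform((8, 10, 32, 32, 3))  # (batch, timesteps, height, width, 3)
    
    mean, var = tf.nn.moments(frames, axes=[0, 1, 2, 3])
    print(mean.shape, var.shape)                    # (3,) (3,) -- one statistic per RGB channel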

    Demo: the pseudocode's mean computation in action

    import tensorflow as tf
    import tensorflow.keras.backend as K
    import numpy as np
    
    x = tf.constant(np.random.randn(8, 10, 4))
    # manual mean over the samples axis, mirroring the pseudocode above
    result1 = tf.add(x[0], tf.add(x[1], tf.add(x[2], tf.add(x[3], tf.add(x[4],
                           tf.add(x[5], tf.add(x[6], x[7]))))))) / 8
    # the same mean computed via reduce_mean
    result2 = tf.reduce_mean(x, axis=0)
    print(K.eval(result1 - result2))
    
    # small differences due to floating-point imprecision
    [[ 2.77555756e-17  0.00000000e+00 -5.55111512e-17 -1.38777878e-17]
     [-2.77555756e-17  2.77555756e-17  0.00000000e+00 -1.38777878e-17]
     [ 0.00000000e+00 -5.55111512e-17  0.00000000e+00 -2.77555756e-17]
     [-1.11022302e-16  2.08166817e-17  2.22044605e-16  0.00000000e+00]
     [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00]
     [-5.55111512e-17  2.77555756e-17 -1.11022302e-16  5.55111512e-17]
     [ 0.00000000e+00  0.00000000e+00  0.00000000e+00 -2.77555756e-17]
     [ 0.00000000e+00  0.00000000e+00  0.00000000e+00 -5.55111512e-17]
     [ 0.00000000e+00 -3.46944695e-17 -2.77555756e-17  1.11022302e-16]
     [-5.55111512e-17  5.55111512e-17  0.00000000e+00  1.11022302e-16]]