I am planning on implementing a batch normalization function similar to this blog (or just using tf.nn.batch_normalization), using tf.nn.moments to compute the mean and variance, but I wish to do it for temporal data, of both vector and image type. I am generally having a little trouble understanding how to set the axes argument correctly in tf.nn.moments.
My input data for vector sequences has shape (batch, timesteps, channels), and my input data for image sequences has shape (batch, timesteps, height, width, 3) (note these are RGB images). In both cases I want the normalization to happen across the entire batch and across timesteps, meaning I am not trying to maintain separate mean/variance for different timesteps.

How do I correctly set axes for different data types (e.g. image, vector) and for temporal/non-temporal data?
The simplest way to think of it: the dimensions listed in axes will be collapsed, and statistics will be computed by iterating over slices along those axes. Example:
import tensorflow as tf
x = tf.random.uniform((8, 10, 4))
print(x, '\n')
print(tf.nn.moments(x, axes=[0]), '\n')
print(tf.nn.moments(x, axes=[0, 1]))
Tensor("random_uniform:0", shape=(8, 10, 4), dtype=float32)
(<tf.Tensor 'moments/Squeeze:0' shape=(10, 4) dtype=float32>,
<tf.Tensor 'moments/Squeeze_1:0' shape=(10, 4) dtype=float32>)
(<tf.Tensor 'moments_1/Squeeze:0' shape=(4,) dtype=float32>,
<tf.Tensor 'moments_1/Squeeze_1:0' shape=(4,) dtype=float32>)
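The output above is from graph mode; under TF2's eager execution the same shape behavior can be confirmed directly, since tf.nn.moments returns concrete tensors:

```python
import tensorflow as tf

x = tf.random.uniform((8, 10, 4))

# axes=[0]: collapse the batch dimension -> one (timesteps, channels) statistic
mean0, var0 = tf.nn.moments(x, axes=[0])
print(mean0.shape, var0.shape)  # (10, 4) (10, 4)

# axes=[0, 1]: collapse batch and timesteps -> per-channel statistics
mean01, var01 = tf.nn.moments(x, axes=[0, 1])
print(mean01.shape, var01.shape)  # (4,) (4,)
```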
From the source, math_ops.reduce_mean is used to compute both the mean and variance, which operates as follows, in pseudocode:
# axes = [0]
mean = (x[0, :, :] + x[1, :, :] + ... + x[7, :, :]) / 8
mean.shape == (10, 4) # each slice's shape is (10, 4), so sum's shape is also (10, 4)
# axes = [0, 1]
mean = (x[0, 0, :] + x[1, 0, :] + ... + x[7, 0, :] +
        x[0, 1, :] + x[1, 1, :] + ... + x[7, 1, :] +
        ... +
        x[0, 9, :] + x[1, 9, :] + ... + x[7, 9, :]) / (8 * 10)
mean.shape == (4, )  # each slice's shape is (4, ), so sum's shape is also (4, )
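The pseudocode above can be checked numerically with a quick NumPy sketch; summing the slices manually should agree with np.mean over the same axes:

```python
import numpy as np

x = np.random.randn(8, 10, 4).astype(np.float32)

# axes=[0]: average the 8 slices of shape (10, 4)
manual = sum(x[i] for i in range(8)) / 8
np.testing.assert_allclose(manual, x.mean(axis=0), atol=1e-5)

# axes=[0, 1]: average all 8 * 10 slices of shape (4,)
manual2 = sum(x[i, t] for i in range(8) for t in range(10)) / (8 * 10)
np.testing.assert_allclose(manual2, x.mean(axis=(0, 1)), atol=1e-5)
```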
In other words, axes=[0] will compute (timesteps, channels) statistics with respect to samples - i.e. iterate over samples, computing the mean & variance of each (timesteps, channels) slice. Thus, for "normalization to happen across the entire batch and across timesteps", without maintaining separate mean/variance per timestep, you just need to also collapse the timesteps dimension (along with samples), and compute statistics by iterating over both samples and timesteps:
axes = [0, 1]
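Putting it together for the vector-sequence case - a minimal sketch of the normalization itself, using tf.nn.batch_normalization with the moments computed this way (offset/scale omitted here for simplicity):

```python
import tensorflow as tf

x = tf.random.uniform((8, 10, 4))  # (batch, timesteps, channels)

# Per-channel statistics, collapsing batch and timesteps
mean, variance = tf.nn.moments(x, axes=[0, 1])

# The (4,)-shaped statistics broadcast against (8, 10, 4) automatically
normalized = tf.nn.batch_normalization(x, mean, variance,
                                       offset=None, scale=None,
                                       variance_epsilon=1e-5)

# Each channel of the result now has ~zero mean and ~unit variance
m, v = tf.nn.moments(normalized, axes=[0, 1])
```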
Same story for images, except now you have three non-channel/sample dimensions (timesteps, height, width), so for the (batch, timesteps, height, width, 3) input you'd do axes = [0, 1, 2, 3] (to collapse samples, timesteps, height, and width, leaving per-channel statistics).
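For the 5D image-sequence input (batch, timesteps, height, width, 3), collapsing everything except the channel axis looks like this (height/width of 16 chosen arbitrarily for the demo):

```python
import tensorflow as tf

x = tf.random.uniform((8, 10, 16, 16, 3))  # (batch, timesteps, height, width, 3)

# Collapse samples, timesteps, height, width -> one statistic per RGB channel
mean, variance = tf.nn.moments(x, axes=[0, 1, 2, 3])
print(mean.shape, variance.shape)  # (3,) (3,)
```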
Demo: the mean computation in action
import tensorflow as tf
import tensorflow.keras.backend as K
import numpy as np
x = tf.constant(np.random.randn(8, 10, 4))
result1 = tf.add(x[0], tf.add(x[1], tf.add(x[2], tf.add(x[3], tf.add(x[4],
tf.add(x[5], tf.add(x[6], x[7]))))))) / 8
result2 = tf.reduce_mean(x, axis=0)
print(K.eval(result1 - result2))
# small differences due to numeric imprecision
[[ 2.77555756e-17 0.00000000e+00 -5.55111512e-17 -1.38777878e-17]
[-2.77555756e-17 2.77555756e-17 0.00000000e+00 -1.38777878e-17]
[ 0.00000000e+00 -5.55111512e-17 0.00000000e+00 -2.77555756e-17]
[-1.11022302e-16 2.08166817e-17 2.22044605e-16 0.00000000e+00]
[ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]
[-5.55111512e-17 2.77555756e-17 -1.11022302e-16 5.55111512e-17]
[ 0.00000000e+00 0.00000000e+00 0.00000000e+00 -2.77555756e-17]
[ 0.00000000e+00 0.00000000e+00 0.00000000e+00 -5.55111512e-17]
[ 0.00000000e+00 -3.46944695e-17 -2.77555756e-17 1.11022302e-16]
[-5.55111512e-17 5.55111512e-17 0.00000000e+00 1.11022302e-16]]