
Is the first axis 0 or 1 in keras.layers.BatchNormalization()?


I am working on the 'Keras_Tutorial_v2a' by Andrew Ng on Coursera and I am confused about the axis parameter in keras.layers.BatchNormalization().

The first few layers of the model are:

from keras.layers import Input, Conv2D, BatchNormalization

X_input = Input(input_shape)  # input_shape = (height, width, channels)
X = Conv2D(32, (3, 3), strides=(1, 1), name='conv0')(X_input)
X = BatchNormalization(axis=3, name='bn0')(X)

Here input_shape is the shape of the images in the dataset: (height, width, channels). So it seems that axis=3 refers to the channels, but shouldn't that be axis=2? I couldn't find documentation specifying this, but in Python, indices and axes usually begin at 0.

So either axes begin at 1 in this function, or I am missing something. Can anyone clarify this for me, please? I'm sure it's something simple!


Solution

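  • Axes do start at 0 in Keras, as elsewhere in Python. What you are missing is the batch dimension: input_shape in Input(input_shape) describes a single example, and Keras prepends the batch axis automatically. The tensor that reaches BatchNormalization therefore has shape (batch, height, width, channels), so axis 0 is the batch axis and the channels sit on axis 3 (equivalently the last axis, -1). That is why tutorials and the Keras/TensorFlow codebase use axis = 3 or axis = -1 for channels-last image data.

    The official documentation agrees: the default is axis=-1, which for a 4D (batch, height, width, channels) tensor is the same as axis=3.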

    https://www.tensorflow.org/api_docs/python/tf/keras/layers/BatchNormalization
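A minimal sketch to check this yourself, assuming Keras 2.x or Keras 3 with the TensorFlow backend, plus NumPy. The 64x64x3 image size and the random batch are hypothetical values chosen only for illustration. It prints the shape of the tensor that BatchNormalization receives, confirms that axis=3 and axis=-1 give identical results, and compares the training-mode output against a plain NumPy per-channel normalization (statistics pooled over axes 0, 1 and 2).

```python
import numpy as np
import keras

# Hypothetical image size for illustration: 64x64 RGB images.
input_shape = (64, 64, 3)  # (height, width, channels): no batch axis here

X_input = keras.Input(input_shape)
X = keras.layers.Conv2D(32, (3, 3), strides=(1, 1), name="conv0")(X_input)

# The tensor BatchNormalization receives has a leading batch axis:
# (batch, height, width, channels) -> axis 0 is batch, axis 3 is channels.
print(tuple(X.shape))  # (None, 62, 62, 32)

# Hypothetical random batch with the same shape as the conv output.
rng = np.random.default_rng(0)
x = rng.normal(size=(8, 62, 62, 32)).astype("float32")

# Two layers that differ only in how the channel axis is written.
bn_pos = keras.layers.BatchNormalization(axis=3)
bn_neg = keras.layers.BatchNormalization(axis=-1)
y_pos = np.asarray(bn_pos(x, training=True))
y_neg = np.asarray(bn_neg(x, training=True))

# One learnable scale (gamma) per entry along axis 3, i.e. per channel.
print(tuple(bn_pos.gamma.shape))  # (32,)

# NumPy reference: per-channel statistics pooled over batch, height, width.
mean = x.mean(axis=(0, 1, 2), keepdims=True)
var = x.var(axis=(0, 1, 2), keepdims=True)  # biased (population) variance
ref = (x - mean) / np.sqrt(var + bn_pos.epsilon)

print(np.allclose(y_pos, y_neg, atol=1e-5))  # True
print(np.allclose(y_pos, ref, atol=1e-4))    # True
```

With the default gamma=1 and beta=0, a freshly built layer in training mode simply subtracts each channel's batch mean and divides by the square root of its batch variance plus epsilon, which is why the NumPy reference matches. Passing axis=2 instead would keep separate statistics for each position along the width axis (gamma would then have shape (62,)), which is not what you want for image channels.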