I have the following code.
import keras

x = keras.layers.Input(batch_shape = (None, 4096))
hidden = keras.layers.Dense(512, activation = 'relu')(x)
hidden = keras.layers.BatchNormalization()(hidden)
hidden = keras.layers.Dropout(0.5)(hidden)
predictions = keras.layers.Dense(80, activation = 'sigmoid')(hidden)
mlp_model = keras.models.Model(input = [x], output = [predictions])
mlp_model.summary()
And this is the model summary:
____________________________________________________________________________________________________
Layer (type)                     Output Shape          Param #     Connected to
====================================================================================================
input_3 (InputLayer)             (None, 4096)          0
____________________________________________________________________________________________________
dense_1 (Dense)                  (None, 512)           2097664     input_3[0][0]
____________________________________________________________________________________________________
batchnormalization_1 (BatchNorma (None, 512)           2048        dense_1[0][0]
____________________________________________________________________________________________________
dropout_1 (Dropout)              (None, 512)           0           batchnormalization_1[0][0]
____________________________________________________________________________________________________
dense_2 (Dense)                  (None, 80)            41040       dropout_1[0][0]
====================================================================================================
Total params: 2,140,752
Trainable params: 2,139,728
Non-trainable params: 1,024
____________________________________________________________________________________________________
The input to the BatchNormalization (BN) layer has size 512. According to the Keras documentation, the output shape of a BN layer is the same as its input shape, which is 512 here.
So why does the BN layer have 2048 parameters?
The BatchNormalization layer in Keras implements the batch normalization paper by Ioffe and Szegedy (2015).
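As you can read there, for batch normalization to work during training, the layer needs to keep track of the distribution of each normalized dimension. To do so, since you are in mode=0 by default, it keeps 4 parameters per feature coming from the previous layer: a learned scale (gamma), a learned shift (beta), and running estimates of the mean and of the spread (variance or standard deviation, depending on the Keras version). The scale and shift are trained by backpropagation. The running statistics are updated from the batches seen during training but not by gradient descent, which is why the summary lists 2*512 = 1,024 non-trainable parameters.
So 4*512 = 2048, which should answer your question.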
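As a quick sanity check, the counts in the summary can be reproduced by hand. The sketch below is plain Python and makes no Keras calls. The helper names dense_params and batchnorm_params are made up for illustration, and it assumes each Dense layer has a bias and the BN layer keeps the four per-feature values described above.

```python
def dense_params(n_in, n_out):
    # Weight matrix (n_in x n_out) plus one bias per output unit.
    return n_in * n_out + n_out


def batchnorm_params(n_features):
    # Assumed split: gamma and beta are trainable,
    # the two running statistics are not.
    trainable = 2 * n_features
    non_trainable = 2 * n_features
    return trainable, non_trainable


dense_1 = dense_params(4096, 512)
bn_trainable, bn_non_trainable = batchnorm_params(512)
dense_2 = dense_params(512, 80)

bn_total = bn_trainable + bn_non_trainable
total = dense_1 + bn_total + dense_2

print("dense_1:", dense_1)
print("batchnormalization_1:", bn_total)
print("dense_2:", dense_2)
print("total:", total)
print("trainable:", total - bn_non_trainable)
print("non-trainable:", bn_non_trainable)
```

The printed values should line up with the Param # column and the totals at the bottom of the summary, with the 1,024 non-trainable parameters coming entirely from the BN layer.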