Tags: python, tensorflow, keras, deep-learning, batch-normalization

Changing BatchNormalization momentum while training in Tensorflow 2


I want the batch normalization running statistics (mean and variance) to converge at the end of training, which requires increasing the batch norm momentum from some initial value to 1.0. I managed to change the momentum using a custom Callback, but it only works if my model is compiled in eager mode. Toy example (it sets momentum=1.0 after epoch zero, after which moving_mean should stop updating):

import tensorflow as tf  # version 2.3.1
import tensorflow_datasets as tfds

ds_train, ds_test = tfds.load("mnist", split=["train", "test"], shuffle_files=True, as_supervised=True)
ds_train = ds_train.batch(128)
ds_test = ds_test.batch(128)

model = tf.keras.models.Sequential(
    [
        tf.keras.layers.Flatten(input_shape=(28, 28)),
        tf.keras.layers.Dense(128),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.ReLU(),
        tf.keras.layers.Dense(10),
    ]
)


model.compile(
    optimizer=tf.keras.optimizers.Adam(0.001),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
    # run_eagerly=True,
)


class BatchNormMomentumCallback(tf.keras.callbacks.Callback):
    def on_epoch_begin(self, epoch, logs=None):
        last_bn_layer = None
        for layer in self.model.layers:
            if isinstance(layer, tf.keras.layers.BatchNormalization):
                if epoch == 0:
                    layer.momentum = 0.99
                else:
                    layer.momentum = 1.0
                last_bn_layer = layer
        if last_bn_layer:
            tf.print("Momentum=" + str(last_bn_layer.moving_mean[-1].numpy()))  # Should not change after epoch 1


batchnorm_decay = BatchNormMomentumCallback()
model.fit(ds_train, epochs=6, validation_data=ds_test, callbacks=[batchnorm_decay], verbose=0)

Output (what I get with run_eagerly=False):

Momentum=0.0
Momentum=-102.20184
Momentum=-106.04614
Momentum=-116.36204
Momentum=-129.995
Momentum=-123.70443

Expected output (what I get with run_eagerly=True):

Momentum=0.0
Momentum=-5.9038606
Momentum=-5.9038606
Momentum=-5.9038606
Momentum=-5.9038606
Momentum=-5.9038606

I guess this happens because in graph mode TF compiles the model into a graph with the momentum baked in as 0.99, and then keeps using that value inside the graph (so the momentum is not updated by BatchNormMomentumCallback).
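
A minimal sketch of what I think is happening (a plain Python value is baked into a tf.function at trace time, whereas a tf.Variable is read on every call):

import tensorflow as tf

momentum_value = 0.99                              # plain Python float, like layer.momentum
momentum_var = tf.Variable(0.99, trainable=False)  # hypothetical tf.Variable alternative

@tf.function
def read_momentum():
    # The Python float is captured as a constant when the function is traced;
    # the Variable is read every time the compiled graph runs.
    return tf.constant(momentum_value), momentum_var.read_value()

print(read_momentum())    # both read 0.99
momentum_value = 1.0      # rebinding the Python name does not affect the traced graph
momentum_var.assign(1.0)  # the Variable update is visible inside the graph
print(read_momentum())    # still 0.99 for the constant, 1.0 for the Variable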

Question: is there a way to update that momentum value inside the compiled graph while training? I want to update the momentum without running eagerly (i.e. keeping run_eagerly=False), because training efficiency is important.


Solution

  • I would recommend simply using a custom training loop for your use case. You will have all the flexibility you need:

    import tensorflow as tf  # version 2.3.1
    import tensorflow_datasets as tfds
    
    ds_train, ds_test = tfds.load("mnist", split=["train", "test"], shuffle_files=True, as_supervised=True)
    ds_train = ds_train.batch(128)
    ds_test = ds_test.batch(128)
    
    model = tf.keras.models.Sequential(
        [
            tf.keras.layers.Flatten(input_shape=(28, 28)),
            tf.keras.layers.Dense(128),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.ReLU(),
            tf.keras.layers.Dense(10),
        ]
    )
    
    optimizer = tf.keras.optimizers.Adam(0.001)
    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    train_acc_metric = tf.keras.metrics.SparseCategoricalAccuracy()
    batch_norm_layer = model.layers[2]
    
    @tf.function
    def train_step(epoch, model, batch):
        if epoch == 0:
            batch_norm_layer.momentum = 0.99
        else:
            batch_norm_layer.momentum = 1.0
    
        with tf.GradientTape() as tape:
            x_batch_train, y_batch_train = batch
    
            logits = model(x_batch_train, training=True)
            loss_value = loss_fn(y_batch_train, logits)
    
        train_acc_metric.update_state(y_batch_train, logits)
        grads = tape.gradient(loss_value, model.trainable_weights)
        optimizer.apply_gradients(zip(grads, model.trainable_weights))
    
    epochs = 6
    for epoch in range(epochs):
        tf.print("\nStart of epoch %d" % (epoch,))
        tf.print("Momentum = ", batch_norm_layer.moving_mean[-1], summarize=-1)
        for batch in ds_train:
            train_step(epoch, model, batch)
            
        train_acc = train_acc_metric.result()
        tf.print("Training acc over epoch: %.4f" % (float(train_acc),))
        train_acc_metric.reset_states()
    
    Start of epoch 0
    Momentum =  0
    Training acc over epoch: 0.9158
    
    Start of epoch 1
    Momentum =  -20.2749767
    Training acc over epoch: 0.9634
    
    Start of epoch 2
    Momentum =  -20.2749767
    Training acc over epoch: 0.9755
    
    Start of epoch 3
    Momentum =  -20.2749767
    Training acc over epoch: 0.9826
    
    Start of epoch 4
    Momentum =  -20.2749767
    Training acc over epoch: 0.9876
    
    Start of epoch 5
    Momentum =  -20.2749767
    Training acc over epoch: 0.9915
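
    Note that this works even though train_step is wrapped in tf.function: epoch is passed in as a plain Python int, and tf.function retraces whenever it sees a new Python value, so the momentum assignment is re-evaluated once per epoch. A small sketch of that retracing behaviour:

    import tensorflow as tf

    @tf.function
    def traced_step(epoch):
        # This print runs only while tracing, so it shows when a new graph is built.
        print("Tracing for epoch", epoch)
        return tf.constant(epoch)

    traced_step(0)  # prints "Tracing for epoch 0" (first trace)
    traced_step(0)  # reuses the existing graph, nothing printed
    traced_step(1)  # prints "Tracing for epoch 1" (new Python value triggers a retrace)

    Retracing once per epoch is cheap here, and it is what lets the plain Python momentum attribute take its new value inside the graph.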
    

    A simple test shows that the tf.function-decorated version performs considerably better:

    import tensorflow as tf  # version 2.3.1
    import tensorflow_datasets as tfds
    import timeit
    
    ds_train, ds_test = tfds.load("mnist", split=["train", "test"], shuffle_files=True, as_supervised=True)
    ds_train = ds_train.batch(128)
    ds_test = ds_test.batch(128)
    
    model = tf.keras.models.Sequential(
        [
            tf.keras.layers.Flatten(input_shape=(28, 28)),
            tf.keras.layers.Dense(128),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.ReLU(),
            tf.keras.layers.Dense(10),
        ]
    )
    
    optimizer = tf.keras.optimizers.Adam(0.001)
    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    train_acc_metric = tf.keras.metrics.SparseCategoricalAccuracy()
    batch_norm_layer = model.layers[2]
    
    @tf.function
    def train_step(epoch, model, batch):
        if epoch == 0:
            batch_norm_layer.momentum = 0.99
        else:
            batch_norm_layer.momentum = 1.0
    
        with tf.GradientTape() as tape:
            x_batch_train, y_batch_train = batch
    
            logits = model(x_batch_train, training=True)
            loss_value = loss_fn(y_batch_train, logits)
    
        train_acc_metric.update_state(y_batch_train, logits)
        grads = tape.gradient(loss_value, model.trainable_weights)
        optimizer.apply_gradients(zip(grads, model.trainable_weights))
    
    def train_step_without_tffunction(epoch, model, batch):
        if epoch == 0:
            batch_norm_layer.momentum = 0.99
        else:
            batch_norm_layer.momentum = 1.0
    
        with tf.GradientTape() as tape:
            x_batch_train, y_batch_train = batch
    
            logits = model(x_batch_train, training=True)
            loss_value = loss_fn(y_batch_train, logits)
    
        train_acc_metric.update_state(y_batch_train, logits)
        grads = tape.gradient(loss_value, model.trainable_weights)
        optimizer.apply_gradients(zip(grads, model.trainable_weights))
    
    epochs = 6
    for epoch in range(epochs):
        tf.print("\nStart of epoch %d" % (epoch,))
        tf.print("Momentum = ", batch_norm_layer.moving_mean[-1], summarize=-1)
        test = True
        for batch in ds_train:
            train_step(epoch, model, batch)
            if test:
              tf.print("TF function:", timeit.timeit(lambda: train_step(epoch, model, batch), number=10))
              tf.print("Eager function:", timeit.timeit(lambda: train_step_without_tffunction(epoch, model, batch), number=10))
              test = False 
        train_acc = train_acc_metric.result()
        tf.print("Training acc over epoch: %.4f" % (float(train_acc),))
        train_acc_metric.reset_states()
    
    Start of epoch 0
    Momentum =  0
    TF function: 0.02285163299893611
    Eager function: 0.11109527599910507
    Training acc over epoch: 0.9229
    
    Start of epoch 1
    Momentum =  -88.1852188
    TF function: 0.024091466999379918
    Eager function: 0.1109461480009486
    Training acc over epoch: 0.9639
    
    Start of epoch 2
    Momentum =  -88.1852188
    TF function: 0.02331122400210006
    Eager function: 0.11751473100230214
    Training acc over epoch: 0.9756
    
    Start of epoch 3
    Momentum =  -88.1852188
    TF function: 0.02656845700039412
    Eager function: 0.1121610670015798
    Training acc over epoch: 0.9830
    
    Start of epoch 4
    Momentum =  -88.1852188
    TF function: 0.02821972700257902
    Eager function: 0.15709391699783737
    Training acc over epoch: 0.9877
    
    Start of epoch 5
    Momentum =  -88.1852188
    TF function: 0.02441513300072984
    Eager function: 0.10921925399816246
    Training acc over epoch: 0.9917