Tags: tensorflow, pytorch, tensorflow2.0, batch-normalization, momentum

How to dynamically update batch norm momentum in TF2?


I found a PyTorch implementation that decays the batch norm momentum parameter from 0.1 in the first epoch to 0.001 in the final epoch. Any suggestions on how to do the same with the batch norm momentum parameter in TF2, whose convention is the complement of PyTorch's (i.e., start at 0.9 and end at 0.999)? For example, this is what the PyTorch code does:

# in training script
momentum = initial_momentum * np.exp(-epoch/args.epochs * np.log(initial_momentum/final_momentum))
model_pos_train.set_bn_momentum(momentum)

# model class function
def set_bn_momentum(self, momentum):
    self.expand_bn.momentum = momentum
    for bn in self.layers_bn:
        bn.momentum = momentum

SOLUTION:

The selected answer below provides a viable solution when using the tf.keras.Model.fit() API. However, I was using a custom training loop. Here is what I did instead:

After each epoch:

import numpy as np

mi = 1 - initial_momentum  # e.g., initial_momentum = 0.9 -> mi = 0.1
mf = 1 - final_momentum    # e.g., final_momentum = 0.999 -> mf = 0.001
momentum = 1 - mi * np.exp(-epoch / epochs * np.log(mi / mf))
model = set_bn_momentum(model, momentum)
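
As a quick sanity check of this schedule (the 80-epoch count below is just a placeholder), it starts at 0.9, passes 0.99 at the halfway point, and reaches 0.999 at the end:

import numpy as np

initial_momentum, final_momentum, epochs = 0.9, 0.999, 80  # epochs is a hypothetical value
mi, mf = 1 - initial_momentum, 1 - final_momentum  # 0.1 and 0.001

for epoch in (0, epochs // 2, epochs):
    momentum = 1 - mi * np.exp(-epoch / epochs * np.log(mi / mf))
    print(epoch, round(momentum, 4))  # 0 0.9, 40 0.99, 80 0.999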

set_bn_momentum function (credit to this article):

import os
import tempfile

import tensorflow as tf


def set_bn_momentum(model, momentum):
    for layer in model.layers:
        if hasattr(layer, 'momentum'):
            print(layer.name, layer.momentum)
            setattr(layer, 'momentum', momentum)

    # Changing the layer attribute only shows up in the model config;
    # serialize the config and rebuild the model so the new momentum takes effect.
    model_json = model.to_json()

    # Save the weights before reloading the model.
    tmp_weights_path = os.path.join(tempfile.gettempdir(), 'tmp_weights.h5')
    model.save_weights(tmp_weights_path)

    # Load the model from the config
    model = tf.keras.models.model_from_json(model_json)

    # Reload the model weights
    model.load_weights(tmp_weights_path, by_name=True)
    return model

This method did not add significant overhead to the training routine.
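
For context, here is a minimal sketch of how this fits into a custom training loop; model, train_dataset, train_step, and the epoch count below are hypothetical placeholders standing in for your own objects:

import numpy as np

initial_momentum, final_momentum = 0.9, 0.999
mi, mf = 1 - initial_momentum, 1 - final_momentum
epochs = 80  # placeholder

for epoch in range(epochs):
    for x_batch, y_batch in train_dataset:   # your tf.data pipeline
        train_step(model, x_batch, y_batch)  # your custom train step

    # Decay the batch norm momentum once per epoch, then rebuild the model
    # so the new value is picked up (see set_bn_momentum above).
    momentum = 1 - mi * np.exp(-epoch / epochs * np.log(mi / mf))
    model = set_bn_momentum(model, momentum)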


Solution

  • You can set an action at the beginning and end of each batch (or epoch) with a callback, so you can adjust any parameter during training; a sketch of a callback that applies the momentum schedule is included at the end of this answer.

    Below are the available callback hooks:

    from tensorflow import keras

    class CustomCallback(keras.callbacks.Callback):
        def on_epoch_begin(self, epoch, logs=None):
            keys = list(logs.keys())
            print("Start epoch {} of training; got log keys: {}".format(epoch, keys))
    
        def on_epoch_end(self, epoch, logs=None):
            keys = list(logs.keys())
            print("End epoch {} of training; got log keys: {}".format(epoch, keys))
    
        def on_train_batch_begin(self, batch, logs=None):
            keys = list(logs.keys())
            print("...Training: start of batch {}; got log keys: {}".format(batch, keys))
    
        def on_train_batch_end(self, batch, logs=None):
            keys = list(logs.keys())
            print("...Training: end of batch {}; got log keys: {}".format(batch, keys))
    
        def on_test_batch_begin(self, batch, logs=None):
            keys = list(logs.keys())
            print("...Evaluating: start of batch {}; got log keys: {}".format(batch, keys))
    
        def on_test_batch_end(self, batch, logs=None):
            keys = list(logs.keys())
            print("...Evaluating: end of batch {}; got log keys: {}".format(batch, keys))
    

    You can access the momentum attribute directly on a BatchNormalization layer:

    batch = tf.keras.layers.BatchNormalization()
    batch.momentum = 0.001
    

    Inside a model you have to specify the correct layer, e.g. by index:

    model.layers[1].momentum = 0.001
    

    You can find more information and examples at writing_your_own_callbacks
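
    Putting the pieces together, here is a minimal sketch (not from the original answer) of a callback that applies the question's exponential momentum schedule when training with model.fit(); the class name and schedule constants are assumptions:

    import numpy as np
    import tensorflow as tf

    class BNMomentumScheduler(tf.keras.callbacks.Callback):
        """Decays BatchNormalization momentum from 0.9 toward 0.999 over training."""

        def __init__(self, epochs, initial_momentum=0.9, final_momentum=0.999):
            super().__init__()
            self.epochs = epochs
            self.mi = 1 - initial_momentum
            self.mf = 1 - final_momentum

        def on_epoch_begin(self, epoch, logs=None):
            momentum = 1 - self.mi * np.exp(-epoch / self.epochs * np.log(self.mi / self.mf))
            for layer in self.model.layers:
                if isinstance(layer, tf.keras.layers.BatchNormalization):
                    layer.momentum = momentum

    # usage (model and train_dataset are placeholders):
    # model.fit(train_dataset, epochs=80, callbacks=[BNMomentumScheduler(epochs=80)])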