
Changing regularization factor during training in Tensorflow


Is there an easy way to do this?

For example, the learning rate can easily be changed using tf.keras.optimizers.schedules:

lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=0.001, decay_steps=10000, decay_rate=0.9)
optimizer = tf.keras.optimizers.SGD(learning_rate=lr_schedule)

Is there an easy way to do the same with the regularization factor? Something like this:

r_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=0.1, decay_steps=10000, decay_rate=0.9)
regularizer = tf.keras.regularizers.L2(l2=r_schedule)

If not, how can I gradually change the regularization factor with minimal effort?


Solution

  • IIUC, you should be able to use a custom callback that re-implements the same / similar logic as tf.keras.optimizers.schedules.ExponentialDecay (though this may go beyond minimal effort):

    import tensorflow as tf

    class Decay(tf.keras.callbacks.Callback):
      """Decays an externally owned `l2` variable after every epoch,
      mirroring the math of tf.keras.optimizers.schedules.ExponentialDecay."""

      def __init__(self, l2, decay_steps, decay_rate, staircase):
        super().__init__()
        self.l2 = l2                    # tf.Variable holding the current factor
        self.decay_steps = decay_steps
        self.decay_rate = decay_rate
        self.staircase = staircase

      def on_epoch_end(self, epoch, logs=None):
        # self.params['steps'] is the number of batches per epoch, so the
        # factor is multiplied by decay_rate ** (steps / decay_steps) once
        # per epoch, compounding across epochs.
        global_step_recomp = self.params.get('steps')
        p = global_step_recomp / self.decay_steps
        if self.staircase:
          p = tf.floor(p)
        self.l2.assign(tf.multiply(
            self.l2, tf.pow(self.decay_rate, p)))

    # Non-trainable variable so the regularizer always reads the latest value.
    l2 = tf.Variable(initial_value=0.01, trainable=False)

    def l2_regularizer(weights):
        tf.print(l2)  # log the current factor each time the loss is computed
        loss = l2 * tf.reduce_sum(tf.square(weights))
        return loss

    model = tf.keras.Sequential()
    model.add(tf.keras.layers.Dense(1, kernel_regularizer=l2_regularizer))
    model.compile(optimizer='adam', loss='mse')
    model.fit(tf.random.normal((50, 1)), tf.random.normal((50, 1)),
              batch_size=4,
              callbacks=[Decay(l2,
                               decay_steps=100000,
                               decay_rate=0.56,
                               staircase=False)],
              epochs=3)
    
    Epoch 1/3
    0.01
    ...
    13/13 [==============================] - 1s 6ms/step - loss: 2.4884
    Epoch 2/3
    0.00999924541
    ...
    13/13 [==============================] - 0s 7ms/step - loss: 2.4541
    Epoch 3/3
    0.00999849103
    ...
    13/13 [==============================] - 0s 10ms/step - loss: 2.4195
    <keras.callbacks.History at 0x7f7a5a4ff5d0>
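    As a sanity check on the numbers above: the per-epoch update the callback applies is just `l2 *= decay_rate ** (steps_per_epoch / decay_steps)`. A minimal plain-Python sketch (assuming 13 batches per epoch, as in the run above) reproduces the values printed in the log:

    ```python
    # Plain-Python reproduction of the callback's per-epoch update.
    l2 = 0.01
    decay_rate, decay_steps, steps_per_epoch = 0.56, 100_000, 13

    for epoch in range(2):
        # Same multiplicative update Decay.on_epoch_end performs.
        l2 *= decay_rate ** (steps_per_epoch / decay_steps)
        print(l2)
    # After epoch 1: ~0.0099992463  (the log shows 0.00999924541 in float32)
    # After epoch 2: ~0.0099984926  (the log shows 0.00999849103 in float32)
    ```

    With decay_steps=100000 the factor barely moves over three short epochs; smaller decay_steps (or a larger decay_rate exponent) gives a more visible schedule.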