I wonder, is there an easy way to change the regularization factor during training?
For example, changing the learning rate can be easily done using tf.keras.optimizers.schedules:
# ExponentialDecay requires the decay horizon and rate in addition to the
# initial value; the numbers below are illustrative.
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=0.001, decay_steps=100000, decay_rate=0.96)
optimizer = tf.keras.optimizers.SGD(learning_rate=lr_schedule)
Is there an easy way to do the same with the regularization factor? Like this:
# Wished-for usage, mirroring the learning-rate example: a decay schedule
# supplied as the L2 factor.
# NOTE(review): regularizers.L2 presumably expects a plain number for `l2`, so
# this is not expected to run as written — confirm against the Keras docs.
r_schedule = tf.keras.optimizers.schedules.ExponentialDecay(0.1)
regularizer = tf.keras.regularizers.L2(l2=r_schedule)
If not, how can I gradually change the regularization factor with minimal effort?
If I understand correctly, you should be able to use a custom callback that implements the same (or similar) logic as tf.keras.optimizers.schedules.ExponentialDecay (though this may go beyond minimal effort):
import tensorflow as tf
class Decay(tf.keras.callbacks.Callback):
    """Exponentially decay an L2 regularization factor at the end of each epoch.

    After every epoch the factor is set to
    ``initial_l2 * decay_rate ** (global_step / decay_steps)``, where
    ``global_step`` is the number of training steps this callback has seen
    and ``initial_l2`` is the value of ``l2`` when the callback was created.

    Args:
        l2: Variable holding the regularization factor. It is updated in
            place through ``assign``.
        decay_steps: Number of training steps after which the factor has
            been multiplied by ``decay_rate`` once.
        decay_rate: Base of the exponential decay.
        staircase: If true, the exponent is floored so the factor only
            changes once per full ``decay_steps`` interval.
    """

    def __init__(self, l2, decay_steps, decay_rate, staircase):
        super().__init__()
        self.l2 = l2
        self.decay_steps = decay_steps
        self.decay_rate = decay_rate
        self.staircase = staircase
        # Snapshot of the starting factor. Decaying from this fixed value
        # (rather than compounding the current one) lets the staircase floor
        # act on the cumulative step count.
        self._initial_l2 = tf.identity(l2)
        self._global_step = 0

    def on_epoch_end(self, epoch, logs=None):
        """Advance the step counter by one epoch and refresh ``l2``.

        Raises:
            ValueError: If the number of steps per epoch is not available in
                ``self.params``.
        """
        steps_per_epoch = self.params.get('steps')
        if steps_per_epoch is None:
            raise ValueError(
                "Decay needs a known number of steps per epoch; "
                "pass `steps_per_epoch` to `fit`.")
        self._global_step += steps_per_epoch
        p = self._global_step / self.decay_steps
        if self.staircase:
            # Floor the cumulative exponent. Flooring a single epoch's share
            # would yield 0 whenever an epoch is shorter than decay_steps,
            # and the factor would never decay.
            p = tf.floor(p)
        self.l2.assign(tf.multiply(
            self._initial_l2, tf.pow(self.decay_rate, p)))
# Non-trainable variable holding the current L2 regularization factor.
l2 = tf.Variable(0.01, trainable=False)
def l2_regularizer(weights):
    """Return the L2 penalty for ``weights`` scaled by the module-level ``l2``.

    The current factor is printed on every call so its value can be followed
    in the training output.
    """
    tf.print(l2)
    squared_norm = tf.reduce_sum(tf.square(weights))
    return l2 * squared_norm
# Single-unit dense model whose kernel is penalised by the custom regularizer.
model = tf.keras.Sequential([
    tf.keras.layers.Dense(1, kernel_regularizer=l2_regularizer),
])
model.compile(optimizer='adam', loss='mse')

# Random toy data: 50 samples with one feature and one target each.
features = tf.random.normal((50, 1))
targets = tf.random.normal((50, 1))
l2_decay = Decay(l2, decay_steps=100000, decay_rate=0.56, staircase=False)

model.fit(features, targets, batch_size=4, callbacks=[l2_decay], epochs=3)
Epoch 1/3
0.01
1/13 [=>............................] - ETA: 8s - loss: 0.63850.01
0.01
0.01
0.01
0.01
0.01
0.01
0.01
9/13 [===================>..........] - ETA: 0s - loss: 2.13940.01
0.01
0.01
0.01
13/13 [==============================] - 1s 6ms/step - loss: 2.4884
Epoch 2/3
0.00999924541
1/13 [=>............................] - ETA: 0s - loss: 1.97210.00999924541
0.00999924541
0.00999924541
0.00999924541
0.00999924541
0.00999924541
0.00999924541
0.00999924541
9/13 [===================>..........] - ETA: 0s - loss: 2.37490.00999924541
0.00999924541
0.00999924541
0.00999924541
13/13 [==============================] - 0s 7ms/step - loss: 2.4541
Epoch 3/3
0.00999849103
1/13 [=>............................] - ETA: 0s - loss: 0.81400.00999849103
0.00999849103
0.00999849103
0.00999849103
0.00999849103
0.00999849103
7/13 [===============>..............] - ETA: 0s - loss: 2.71970.00999849103
0.00999849103
0.00999849103
0.00999849103
0.00999849103
0.00999849103
13/13 [==============================] - 0s 10ms/step - loss: 2.4195
<keras.callbacks.History at 0x7f7a5a4ff5d0>