To find out how to implement both an ANN with exponential decay and one with a constant learning rate, I looked it up here: https://www.tensorflow.org/api_docs/python/tf/compat/v1/train/exponential_decay
I have some questions:
...
global_step = tf.Variable(0, trainable=False)
starter_learning_rate = 0.1
learning_rate = tf.compat.v1.train.exponential_decay(starter_learning_rate,
                                                     global_step,
                                                     100000, 0.96, staircase=True)
# Passing global_step to minimize() will increment it at each step.
learning_step = (
    tf.compat.v1.train.GradientDescentOptimizer(learning_rate)
    .minimize(...my loss..., global_step=global_step)
)
When global_step is set to a variable with the value 0, doesn't that mean we will have no decay? Since

decayed_learning_rate = learning_rate * decay_rate ^ (global_step / decay_steps)

it follows that if global_step = 0, then decayed_learning_rate = learning_rate. Is this right, or am I making a mistake here?
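To make my reading of the formula concrete, here is a plain-Python evaluation of it (the helper function is my own sketch, not part of TensorFlow):

def decayed_lr(base_lr, decay_rate, global_step, decay_steps, staircase=False):
    # Direct transcription of the documented formula; with staircase=True
    # the exponent is an integer division, so the rate drops in discrete jumps.
    exponent = global_step // decay_steps if staircase else global_step / decay_steps
    return base_lr * decay_rate ** exponent

print(decayed_lr(0.1, 0.96, 0, 100000))       # 0.1    -> no decay at step 0
print(decayed_lr(0.1, 0.96, 50000, 100000))   # ~0.098 -> partway through a period
print(decayed_lr(0.1, 0.96, 100000, 100000))  # 0.096  -> one full decay period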
Furthermore, I am a bit confused as to what exactly the 100,000 steps refer to. What exactly is one step? Is it every time an input has been fully fed through the network and backpropagated?
I hope this example clears up your doubt.
import tensorflow as tf

# model, loss_fn and train_dataset are assumed to be defined elsewhere
# (the dataset here yields batches of 100 examples).
epochs = 10
global_step = tf.Variable(0, trainable=False, dtype=tf.int32)
starter_learning_rate = 1.0

for epoch in range(epochs):
    print("Starting Epoch {}/{}".format(epoch + 1, epochs))
    for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):
        with tf.GradientTape() as tape:
            logits = model(x_batch_train, training=True)
            loss_value = loss_fn(y_batch_train, logits)
        grads = tape.gradient(loss_value, model.trainable_weights)
        # In eager mode this returns a callable that computes the decayed
        # rate from the current value of global_step.
        learning_rate = tf.compat.v1.train.exponential_decay(
            starter_learning_rate,
            global_step,
            100000,
            0.96
        )
        optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate)
        optimizer.apply_gradients(zip(grads, model.trainable_weights))
        print("Global Step: {} Learning Rate: {} Examples Processed: {}".format(
            global_step.numpy(), learning_rate(), (step + 1) * 100))
        global_step.assign_add(1)
Output:
Starting Epoch 1/10
Global Step: 0 Learning Rate: 1.0 Examples Processed: 100
Global Step: 1 Learning Rate: 0.9999996423721313 Examples Processed: 200
Global Step: 2 Learning Rate: 0.9999992251396179 Examples Processed: 300
Global Step: 3 Learning Rate: 0.9999988079071045 Examples Processed: 400
Global Step: 4 Learning Rate: 0.9999983906745911 Examples Processed: 500
Global Step: 5 Learning Rate: 0.9999979734420776 Examples Processed: 600
Global Step: 6 Learning Rate: 0.9999975562095642 Examples Processed: 700
Global Step: 7 Learning Rate: 0.9999971389770508 Examples Processed: 800
Global Step: 8 Learning Rate: 0.9999967217445374 Examples Processed: 900
Global Step: 9 Learning Rate: 0.9999963045120239 Examples Processed: 1000
Global Step: 10 Learning Rate: 0.9999958872795105 Examples Processed: 1100
Global Step: 11 Learning Rate: 0.9999954700469971 Examples Processed: 1200
Starting Epoch 2/10
Global Step: 12 Learning Rate: 0.9999950528144836 Examples Processed: 100
Global Step: 13 Learning Rate: 0.9999946355819702 Examples Processed: 200
Global Step: 14 Learning Rate: 0.9999942183494568 Examples Processed: 300
Global Step: 15 Learning Rate: 0.9999938607215881 Examples Processed: 400
Global Step: 16 Learning Rate: 0.9999934434890747 Examples Processed: 500
Global Step: 17 Learning Rate: 0.999993085861206 Examples Processed: 600
Global Step: 18 Learning Rate: 0.9999926686286926 Examples Processed: 700
Global Step: 19 Learning Rate: 0.9999922513961792 Examples Processed: 800
Global Step: 20 Learning Rate: 0.9999918341636658 Examples Processed: 900
Global Step: 21 Learning Rate: 0.9999914169311523 Examples Processed: 1000
Global Step: 22 Learning Rate: 0.9999909996986389 Examples Processed: 1100
Global Step: 23 Learning Rate: 0.9999905824661255 Examples Processed: 1200
Now, if you keep global_step at 0 (i.e. remove the increment at the end of the loop above), the learning rate stays constant. Output:
Starting Epoch 1/10
Global Step: 0 Learning Rate: 1.0 Examples Processed: 100
Global Step: 0 Learning Rate: 1.0 Examples Processed: 200
Global Step: 0 Learning Rate: 1.0 Examples Processed: 300
Global Step: 0 Learning Rate: 1.0 Examples Processed: 400
Global Step: 0 Learning Rate: 1.0 Examples Processed: 500
Global Step: 0 Learning Rate: 1.0 Examples Processed: 600
Global Step: 0 Learning Rate: 1.0 Examples Processed: 700
Global Step: 0 Learning Rate: 1.0 Examples Processed: 800
Global Step: 0 Learning Rate: 1.0 Examples Processed: 900
Global Step: 0 Learning Rate: 1.0 Examples Processed: 1000
Global Step: 0 Learning Rate: 1.0 Examples Processed: 1100
Global Step: 0 Learning Rate: 1.0 Examples Processed: 1200
Starting Epoch 2/10
Global Step: 0 Learning Rate: 1.0 Examples Processed: 100
Global Step: 0 Learning Rate: 1.0 Examples Processed: 200
Global Step: 0 Learning Rate: 1.0 Examples Processed: 300
Global Step: 0 Learning Rate: 1.0 Examples Processed: 400
Global Step: 0 Learning Rate: 1.0 Examples Processed: 500
Global Step: 0 Learning Rate: 1.0 Examples Processed: 600
Global Step: 0 Learning Rate: 1.0 Examples Processed: 700
Global Step: 0 Learning Rate: 1.0 Examples Processed: 800
Global Step: 0 Learning Rate: 1.0 Examples Processed: 900
Global Step: 0 Learning Rate: 1.0 Examples Processed: 1000
Global Step: 0 Learning Rate: 1.0 Examples Processed: 1100
Global Step: 0 Learning Rate: 1.0 Examples Processed: 1200
Suggestion - instead of using tf.compat.v1.train.exponential_decay, use tf.keras.optimizers.schedules.ExponentialDecay. This is what the simplest example would look like:
def create_model1():
    initial_learning_rate = 0.01
    lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
        initial_learning_rate,
        decay_steps=100000,
        decay_rate=0.96,
        staircase=True)
    model = tf.keras.Sequential()
    model.add(tf.keras.Input(shape=(5,)))
    model.add(tf.keras.layers.Dense(units=6, activation='relu', name='d1'))
    model.add(tf.keras.layers.Dense(units=2, activation='softmax', name='O2'))
    model.compile(optimizer=tf.keras.optimizers.SGD(learning_rate=lr_schedule),
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])
    return model

model = create_model1()
model.fit(x, y, batch_size=100, epochs=100)
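As for your question about what a step is: one step is one optimizer update, i.e. one batch. You can see this by calling the schedule object directly with a step number (LearningRateSchedule objects are callable):

lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    0.01, decay_steps=100000, decay_rate=0.96, staircase=True)

for step in [0, 50000, 100000, 200000]:
    # With staircase=True the exponent is floor(step / decay_steps),
    # so the rate only drops every 100,000 batches.
    print(step, float(lr_schedule(step)))
# -> 0.01, 0.01, 0.0096, 0.009216 (up to float rounding)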
You can also use a callback such as tf.keras.callbacks.LearningRateScheduler to implement your decay.
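A minimal sketch of the callback route, assuming a model compiled with a plain (constant) starting rate; the per-epoch schedule function here is my own illustration:

def schedule(epoch, lr):
    # Hypothetical epoch-level decay: multiply the current rate by 0.96
    # once per epoch (an epoch-based stand-in for the step-based schedule).
    return lr * 0.96

lr_callback = tf.keras.callbacks.LearningRateScheduler(schedule, verbose=1)

model.compile(optimizer=tf.keras.optimizers.SGD(learning_rate=0.01),
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
model.fit(x, y, batch_size=100, epochs=100, callbacks=[lr_callback])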