
tf.Keras learning rate schedules—pass to optimizer or callbacks?


I just wanted to set up a learning rate schedule for my first CNN and I found there are various ways of doing so:

  1. One can include the schedule in callbacks using tf.keras.callbacks.LearningRateScheduler()
  2. One can pass it to an optimizer using tf.keras.optimizers.schedules.LearningRateSchedule()

Now I wondered: are there any differences, and if so, what are they? If it makes no difference, why do both alternatives exist? Is there a historical reason (and which method should be preferred)?

Can someone elaborate?


Solution

  • Both tf.keras.callbacks.LearningRateScheduler() and tf.keras.optimizers.schedules.LearningRateSchedule() provide the same functionality, i.e. they implement a learning rate decay while training the model.

    A visible difference is that tf.keras.callbacks.LearningRateScheduler takes a function in its constructor, as mentioned in the docs:

    tf.keras.callbacks.LearningRateScheduler(schedule, verbose=0)
    

    schedule: a function that takes an epoch index (integer, indexed from 0) and current learning rate (float) as inputs and returns a new learning rate as output (float).

    The schedule function returns a learning rate given the current epoch index. To implement specific LR decays such as exponential decay or polynomial decay, you need to code them in this schedule function yourself.
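
    A minimal sketch of such a schedule function is shown below; the function name, the halving factor and the 10-epoch interval are illustrative choices, not from the docs:

    import tensorflow as tf

    # Illustrative schedule: halve the learning rate every 10 epochs.
    def decay_schedule(epoch, lr):
        if epoch > 0 and epoch % 10 == 0:
            return lr * 0.5
        return lr

    lr_callback = tf.keras.callbacks.LearningRateScheduler(decay_schedule, verbose=1)

    # model.fit(x_train, y_train, epochs=50, callbacks=[lr_callback])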

    On the other hand, tf.keras.optimizers.schedules.LearningRateSchedule() is the base class for learning rate schedules. The built-in decays in tf.keras.optimizers.schedules.*, such as ExponentialDecay, PolynomialDecay or InverseTimeDecay, inherit from it, so this module offers the LR decay methods commonly used in ML out of the box. Moreover, to implement a custom LR decay, your class needs to inherit from tf.keras.optimizers.schedules.LearningRateSchedule and override methods like __call__ and __init__, as mentioned in the docs:

    To implement your own schedule object, you should implement the __call__ method, which takes a step argument (scalar integer tensor, the current training step count).
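
    A minimal sketch of such a subclass, assuming a simple inverse-time style decay; the class name and hyperparameters are illustrative:

    import tensorflow as tf

    class MyInverseDecay(tf.keras.optimizers.schedules.LearningRateSchedule):
        def __init__(self, initial_lr, decay_rate):
            self.initial_lr = initial_lr
            self.decay_rate = decay_rate

        def __call__(self, step):
            # step is the current training step count (scalar integer tensor)
            step = tf.cast(step, tf.float32)
            return self.initial_lr / (1.0 + self.decay_rate * step)

    # The schedule object is passed directly to the optimizer:
    optimizer = tf.keras.optimizers.Adam(learning_rate=MyInverseDecay(1e-3, 1e-4))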

    Conclusion:

    • If you want to use a built-in LR decay, use the schedules provided in the tf.keras.optimizers.schedules.* module (see the sketch after this list).

    • If you need a simple custom LR decay that only requires the epoch index (and optionally the current LR) as arguments, use tf.keras.callbacks.LearningRateScheduler.

    • If your custom LR decay needs more than just the epoch index (e.g. the training step), create a new class that inherits from tf.keras.optimizers.schedules.LearningRateSchedule.
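
    As referenced in the first point above, a minimal sketch of passing a built-in schedule to an optimizer; ExponentialDecay is used here with illustrative hyperparameters:

    import tensorflow as tf

    # Built-in exponential decay: lr = 0.01 * 0.9 ** (step / 1000)
    lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
        initial_learning_rate=0.01,
        decay_steps=1000,
        decay_rate=0.9)

    optimizer = tf.keras.optimizers.SGD(learning_rate=lr_schedule)

    # model.compile(optimizer=optimizer, loss="sparse_categorical_crossentropy")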