NotImplementedError: Learning rate schedule must override get_config

I have created a custom learning rate schedule using tf.keras, and I get this error when saving the model:

NotImplementedError: Learning rate schedule must override get_config

The class looks like this:

class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):

    def __init__(self, d_model, warmup_steps=4000):
        super(CustomSchedule, self).__init__()

        self.d_model = d_model
        self.d_model = tf.cast(self.d_model, tf.float32)

        self.warmup_steps = warmup_steps

    def __call__(self, step):
        arg1 = tf.math.rsqrt(step)
        arg2 = step * (self.warmup_steps**-1.5)

        return tf.math.rsqrt(self.d_model) * tf.math.minimum(arg1, arg2)

    def get_config(self):
        config = {
            'd_model': self.d_model,
            'warmup_steps': self.warmup_steps,
        }
        base_config = super(CustomSchedule, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))


  • When you use a custom subclassed model, saving the model architecture is a bit tricky. If you only need the weights, it is easier to save them with Model.save_weights().

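    The error itself comes from the call to super(CustomSchedule, self).get_config(): the base LearningRateSchedule class does not implement get_config and raises NotImplementedError instead. If you return the config dictionary directly, you will not see that error: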

      def get_config(self):
          # Return only this class's own arguments; do not call super().get_config().
          config = {
              'd_model': self.d_model,
              'warmup_steps': self.warmup_steps,
          }
          return config
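
To see why the fix above works without installing TensorFlow, here is a minimal pure-Python sketch of the same pattern. BaseSchedule is a hypothetical stand-in for tf.keras.optimizers.schedules.LearningRateSchedule: it is assumed to mirror the real base class only in that get_config raises the error from the question and from_config calls cls(**config). BrokenSchedule and FixedSchedule are likewise illustrative names. The sketch also stores d_model as a plain number rather than a tensor (an assumption, unlike the __init__ in the question) so that the config can be dumped to JSON.

```python
import json


class BaseSchedule:
    """Hypothetical stand-in for the Keras LearningRateSchedule base class."""

    def get_config(self):
        # Assumed to mirror the real base class: it has no config of its own.
        raise NotImplementedError("Learning rate schedule must override get_config")

    @classmethod
    def from_config(cls, config):
        # Rebuild the schedule by passing the config back to __init__.
        return cls(**config)


class BrokenSchedule(BaseSchedule):
    def __init__(self, d_model, warmup_steps=4000):
        self.d_model = d_model
        self.warmup_steps = warmup_steps

    def get_config(self):
        config = {"d_model": self.d_model, "warmup_steps": self.warmup_steps}
        base_config = super().get_config()  # raises NotImplementedError
        return {**base_config, **config}


class FixedSchedule(BaseSchedule):
    def __init__(self, d_model, warmup_steps=4000):
        # Keep plain Python numbers so the config stays JSON-serializable.
        self.d_model = d_model
        self.warmup_steps = warmup_steps

    def get_config(self):
        # The keys match the __init__ argument names, so from_config can
        # round-trip the object.
        return {"d_model": self.d_model, "warmup_steps": self.warmup_steps}


def raises_not_implemented(schedule):
    """Return True if schedule.get_config() raises NotImplementedError."""
    try:
        schedule.get_config()
    except NotImplementedError:
        return True
    return False


broken_fails = raises_not_implemented(BrokenSchedule(512))
fixed_config = FixedSchedule(512).get_config()
restored = FixedSchedule.from_config(json.loads(json.dumps(fixed_config)))
```

The design point is that the config keys should match the constructor's argument names, since that is what lets from_config rebuild the schedule from a saved config.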