Tags: python-3.x, tensorflow, callback, tensorflow2.0, learning-rate

How to change the learning rate in TensorFlow depending on the number of batches and epochs?


Is it possible to implement the following scenario with TensorFlow:

For the first N batches, the learning rate should increase from 0 to 0.001. Once that number of batches has been reached, the learning rate should then slowly decrease from 0.001 to 0.00001 with each epoch.

How can I combine these two behaviours in a single callback? TensorFlow offers tf.keras.callbacks.LearningRateScheduler as well as the callback methods on_train_batch_begin() and on_train_batch_end(), but I can't work out how to combine them.
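
For reference, LearningRateScheduler on its own is only called once per epoch, so it can express the decay part but not the per-batch warm-up. A minimal sketch of that epoch-only usage (the decay factor below is just illustrative, not part of my actual requirement):

    import tensorflow as tf

    # Epoch-only schedule: adjusts the learning rate at the start of each
    # epoch, but offers no hook for the per-batch warm-up phase.
    def decay(epoch, lr):
        return max(lr * 0.9, 0.00001)

    lr_scheduler = tf.keras.callbacks.LearningRateScheduler(decay, verbose=0)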

Can someone suggest an approach for building such a combined callback that depends on both the batch number and the epoch?


Solution

  • Something like this would work. I didn't test this and I didn't try to perfect it...but the pieces are there so that you can get it working how you like.

    import tensorflow as tf
    from tensorflow.keras.callbacks import Callback
    import numpy as np

    class LRSetter(Callback):

        def __init__(self, start_lr=0, middle_lr=0.001, end_lr=0.00001,
                     start_mid_batches=200, end_epochs=2000):
            super().__init__()

            # Per-batch warm-up values from start_lr to middle_lr
            self.start_mid_lr = np.linspace(start_lr, middle_lr, start_mid_batches)
            # Not exactly right since you'll have gone through a couple epochs
            # but you get the picture
            self.mid_end_lr = np.linspace(middle_lr, end_lr, end_epochs)

            self.start_mid_batches = start_mid_batches
            self.epoch_takeover = False

        def on_train_batch_begin(self, batch, logs=None):
            # Warm up per batch until the warm-up range is exhausted,
            # then hand control over to the per-epoch schedule.
            if not self.epoch_takeover and batch < self.start_mid_batches:
                tf.keras.backend.set_value(self.model.optimizer.lr,
                                           self.start_mid_lr[batch])
            else:
                self.epoch_takeover = True

        def on_epoch_begin(self, epoch, logs=None):
            if self.epoch_takeover:
                # Clamp the index so training past end_epochs keeps the final lr
                idx = min(epoch, len(self.mid_end_lr) - 1)
                tf.keras.backend.set_value(self.model.optimizer.lr,
                                           self.mid_end_lr[idx])
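
    A hypothetical usage sketch, in case it helps (the model, data, and fit arguments below are placeholders, not part of the answer above):

    model.fit(x_train, y_train,
              batch_size=32,
              epochs=2000,
              callbacks=[LRSetter(start_mid_batches=200, end_epochs=2000)])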