Tags: python, keras, deep-learning, tf.keras

Decay parameter of Adam optimizer in Keras


I think the Adam optimizer is designed so that it automatically adjusts the learning rate. But there is also an option to explicitly specify the decay in the Adam parameter options in Keras, and I want to clarify its effect. If we compile the model with, say, decay=0.01 and lr=0.001, and then fit the model for 50 epochs, does the learning rate get reduced by a factor of 0.01 after each epoch?

Is there any way to specify that the learning rate should decay only after a certain number of epochs?

In PyTorch there is a different implementation called AdamW, which is not present in the standard Keras library. Is it the same as varying the decay after every epoch, as described above?

Thanks in advance for the reply.


Solution

  • From the source code, decay adjusts lr per iteration according to

    lr = lr * (1. / (1. + decay * iterations))  # simplified
    

    This is epoch-independent: iterations is incremented by 1 on each batch fit (e.g. each time train_on_batch is called, or once per batch in model.fit(x) - usually len(x) // batch_size batches per epoch). So with lr = 0.001 and decay = 0.01, the learning rate is not cut by a factor of 0.01 per epoch; it shrinks a little on every batch, as the sketch below illustrates.
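    A minimal numeric sketch of that formula (plain Python, not a Keras API call), assuming lr = 0.001, decay = 0.01, and a hypothetical 100 batches per epoch:

    base_lr = 0.001
    decay = 0.01
    batches_per_epoch = 100  # hypothetical; in practice roughly len(x) // batch_size

    for epoch in (1, 10, 50):
        iterations = epoch * batches_per_epoch
        effective_lr = base_lr * (1. / (1. + decay * iterations))
        print(f"after epoch {epoch:2d}: lr = {effective_lr:.6f}")

    # after epoch  1: lr = 0.000500
    # after epoch 10: lr = 0.000091
    # after epoch 50: lr = 0.000020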

    To implement what you've described, you can use a callback as below:

    from keras.callbacks import LearningRateScheduler
    def decay_schedule(epoch, lr):
        # multiply lr by 0.1 every 5 epochs; use epoch % 1 to decay after every epoch
        if (epoch % 5 == 0) and (epoch != 0):
            lr = lr * 0.1
        return lr
    
    lr_scheduler = LearningRateScheduler(decay_schedule)
    model.fit(x, y, epochs=50, callbacks=[lr_scheduler])
    

    The LearningRateScheduler takes a function as an argument, and that function is fed the epoch index and the current lr at the beginning of each epoch by .fit. It returns the updated lr - so at the next epoch, the function is fed the updated lr.

    Also, there is a Keras implementation of AdamW, NadamW, and SGDW, by me - Keras AdamW.
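    Note that AdamW's decoupled weight decay is not the same thing as the lr decay described above: it shrinks the weights themselves at each update step rather than scaling the learning rate down over time. A minimal sketch, assuming the TensorFlow Addons package (not the Keras AdamW package linked above) and an existing compiled-ready model, as in the snippet earlier:

    import tensorflow_addons as tfa

    # decoupled weight decay: each update step also shrinks the weights
    # directly, independently of any schedule applied to the learning rate
    optimizer = tfa.optimizers.AdamW(learning_rate=1e-3, weight_decay=1e-4)
    model.compile(optimizer=optimizer, loss='mse')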



    Clarification: the first epoch of .fit() invokes on_epoch_begin with epoch = 0 - if we don't wish lr to be decayed immediately, we should add an epoch != 0 check in decay_schedule, as done above. Then epoch denotes how many epochs have already passed - so when epoch = 5, the decay is applied. The sketch below traces this.
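    To see when the decay actually kicks in, the schedule can be traced outside of .fit - a quick sketch calling the decay_schedule defined above, with an assumed starting lr of 0.001:

    lr = 0.001
    for epoch in range(12):
        lr = decay_schedule(epoch, lr)
        print(f"epoch {epoch:2d}: lr = {lr:.6f}")

    # lr stays at 0.001000 for epochs 0-4, drops to 0.000100 at epoch 5,
    # holds through epoch 9, then drops to 0.000010 at epoch 10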