Tags: python, tensorflow, keras, adam

Resume training with Adam optimizer in Keras


My question is quite straightforward, but I can't find a definitive answer online (so far).

I have saved the weights of a Keras model trained with the Adam optimizer after a set number of training epochs, using:

callback = tf.keras.callbacks.ModelCheckpoint(filepath=path, save_weights_only=True)
model.fit(X,y,callbacks=[callback])

When I resume training after closing my Jupyter notebook, can I simply use:

model.load_weights(path)

to continue training?

Since Adam is dependent on the epoch number (such as in the case of learning rate decay), I would like to know the easiest way to resume training in the same conditions as before.
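(For context: in the standard formulation, Adam's time dependence enters through its bias-correction terms, which use the step count t rather than the epoch. A minimal NumPy sketch of a single Adam update, with illustrative names, following the original paper:)

import numpy as np

# One Adam step for a single parameter (Kingma & Ba, 2014).
# m, v are the running first/second moment estimates; t is the step count (starting at 1).
def adam_step(param, grad, m, v, t, lr=0.001, beta1=0.9, beta2=0.999, eps=1e-7):
    m = beta1 * m + (1 - beta1) * grad        # first moment estimate
    v = beta2 * v + (1 - beta2) * grad ** 2   # second moment estimate
    m_hat = m / (1 - beta1 ** t)              # bias correction depends on t,
    v_hat = v / (1 - beta2 ** t)              # the iteration count, not the epoch
    param = param - lr * m_hat / (np.sqrt(v_hat) + eps)
    return param, m, v

So resuming "in the same conditions" would in principle require m, v, and t, in addition to the model weights.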

Following ibarrond's answer, I have written a small custom callback.

import pickle
import tensorflow as tf

optim = tf.keras.optimizers.Adam()
model.compile(optimizer=optim, loss='categorical_crossentropy', metrics=['accuracy'])

weight_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path, save_weights_only=True, verbose=1, save_best_only=False)

class optim_callback(tf.keras.callbacks.Callback):
    '''Custom callback to save the optimiser configuration after each epoch'''

    def on_epoch_end(self, epoch, logs=None):
        optim_state = optim.get_config()  # dictionary of hyperparameters
        with open(optim_state_pkl, 'wb') as f_out:
            pickle.dump(optim_state, f_out)

model.fit(X, y, callbacks=[weight_callback, optim_callback()])

When I resume training:

model.load_weights(checkpoint_path)
with open(optim_state_pkl, 'rb') as f_in:
    optim_state = pickle.load(f_in)
optim = tf.keras.optimizers.Adam.from_config(optim_state)  # pass this to model.compile()

I would just like to check if this is correct. Many thanks again!!

Addendum: On further reading of the default Keras implementation of Adam and the original Adam paper, I believe that the default Adam does not depend on the epoch number, only on the iteration number. Therefore, this is unnecessary. However, the code may still be useful for anyone who wishes to keep track of other optimisers.
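For anyone who needs to restore the full optimizer state (the moment estimates and the iteration count) rather than just the hyperparameters, one option in TF2 is tf.train.Checkpoint, which tracks optimizer variables alongside the model weights. A minimal sketch, reusing the model and optim built above (the checkpoint directory name is arbitrary):

import tensorflow as tf

# Track both the model weights and the optimizer's slot variables / iteration count.
ckpt = tf.train.Checkpoint(model=model, optimizer=optim)
manager = tf.train.CheckpointManager(ckpt, directory='./adam_ckpts', max_to_keep=3)

manager.save()  # call after training (or from a callback at the end of each epoch)

# In a fresh session: rebuild and compile the model the same way, then:
ckpt.restore(manager.latest_checkpoint)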


Solution

  • In order to capture the configuration of your optimizer (its hyperparameters, such as the learning rate and beta values, though not its internal moment estimates), you should store it using the method get_config(). This method returns a dictionary of the options that can be serialized and stored in a file using pickle.

    To restart the process, reload the dictionary with pickle (note that pickle.load takes an open file object, not a path) to retrieve the configuration, and then generate your Adam optimizer using the classmethod from_config(d) of the Keras Adam optimizer.
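Put together, a sketch of the full save/restore round trip with pickle (file name as above):

import pickle
import tensorflow as tf

optim = tf.keras.optimizers.Adam(learning_rate=0.001)

# Save the configuration (hyperparameters only, not the moment estimates).
with open('my_saved_tfconf.txt', 'wb') as f:
    pickle.dump(optim.get_config(), f)

# Restore: pickle works with file objects, not paths.
with open('my_saved_tfconf.txt', 'rb') as f:
    d = pickle.load(f)
restored_optim = tf.keras.optimizers.Adam.from_config(d)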