
Loading a saved model to resume training


I'm training a ResNet model to classify car brands.

I saved the weights during training for every epoch.

For a test, I stopped the training at epoch 3.

import os

import tensorflow as tf

# checkpoint = ModelCheckpoint("best_model.hdf5", monitor='loss', verbose=1)
checkpoint_path = "weights/cp-{epoch:04d}.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)

cp_callback = tf.keras.callbacks.ModelCheckpoint(
    checkpoint_path, verbose=1,
    # Save weights, every epoch.
    save_freq='epoch')

model.save_weights(checkpoint_path.format(epoch=0))

history = model.fit_generator(
    training_set,
    validation_data = test_set,
    epochs = 50,
    steps_per_epoch = len(training_set),
    validation_steps = len(test_set),
    callbacks = [cp_callback]
)

However, when I load the saved model, I can't tell whether training resumes from the last saved epoch, because the output starts at epoch 1/50 again. Below is the code I use to load the last saved model.

from keras.models import Sequential, load_model
# load the model
new_model = load_model('./weights/cp-0003.ckpt')

# fit the model
history = new_model.fit_generator(
    training_set,
    validation_data = test_set,
    epochs = 50,
    steps_per_epoch = len(training_set),
    validation_steps = len(test_set),
    callbacks = [cp_callback]
)

This is what it looks like: [Image: training output showing that resuming from the saved weights starts at epoch 1/50 again]

Can someone please help?


Solution

  • You can use the initial_epoch argument of fit_generator. It defaults to 0, but you can set it to any positive integer, typically the number of epochs already completed:

    import os
    
    import tensorflow as tf
    from tensorflow.keras.models import load_model
    
    checkpoint_path = "weights/cp-{epoch:04d}.ckpt"
    checkpoint_dir = os.path.dirname(checkpoint_path)
    
    cp_callback = tf.keras.callbacks.ModelCheckpoint(
        checkpoint_path, verbose=1,
        # save_weights_only defaults to False, so the full model is saved
        # every epoch; that is what lets load_model() reload it below.
        save_freq='epoch')
    
    model.save_weights(checkpoint_path.format(epoch=0))
    
    history = model.fit_generator(
        training_set,
        validation_data=test_set,
        epochs=3,
        steps_per_epoch=len(training_set),
        validation_steps=len(test_set),
        callbacks=[cp_callback]
    )
    
    
    new_model = load_model('./weights/cp-0003.ckpt')
    
    # fit the model
    history = new_model.fit_generator(
        training_set,
        validation_data=test_set,
        epochs=50,
        steps_per_epoch=len(training_set),
        validation_steps=len(test_set),
        callbacks=[cp_callback],
        initial_epoch=3
    )
    

    This will train your model for 50 - 3 = 47 additional epochs, and the progress output will start at Epoch 4/50 instead of Epoch 1/50.
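
A quick way to check that arithmetic is the plain-Python sketch below (no TensorFlow needed). It is meant to mirror how Keras numbers epochs when fit() or fit_generator() receives initial_epoch, under two stated assumptions: the 0-based epoch loop runs from initial_epoch to epochs - 1, and both the progress line and ModelCheckpoint's {epoch} placeholder display index + 1. The function resume_schedule is hypothetical and not part of Keras.

```python
def resume_schedule(initial_epoch, epochs, pattern="weights/cp-{epoch:04d}.ckpt"):
    """Hypothetical helper: list (progress label, checkpoint path) per epoch.

    Assumes Keras runs 0-based epoch indices initial_epoch .. epochs - 1 and
    shows index + 1 both in "Epoch N/M" and in ModelCheckpoint's {epoch}.
    """
    schedule = []
    for index in range(initial_epoch, epochs):
        shown = index + 1  # 1-based number that Keras prints and formats
        schedule.append((f"Epoch {shown}/{epochs}", pattern.format(epoch=shown)))
    return schedule


runs = resume_schedule(initial_epoch=3, epochs=50)
print(len(runs))   # 47
print(runs[0])     # ('Epoch 4/50', 'weights/cp-0004.ckpt')
print(runs[-1])    # ('Epoch 50/50', 'weights/cp-0050.ckpt')
```

So the resumed run prints Epoch 4/50 first and its first checkpoint is cp-0004.ckpt, continuing the numbering of the first run. Without initial_epoch, the resumed run would start again at cp-0001.ckpt and overwrite the earlier files.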


    Some remarks about your code if you use TensorFlow 2.x:

    • fit_generator is deprecated, since fit now accepts generators directly
    • you should replace imports from keras... with imports from tensorflow.keras...
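
The code above hardcodes both cp-0003.ckpt and initial_epoch=3. A minimal sketch of deriving both from the files in weights/ is shown below, assuming the cp-{epoch:04d}.ckpt naming used in the question. The helper latest_checkpoint and the pattern CKPT_RE are hypothetical, not Keras functions. Because ModelCheckpoint names the file after the number of epochs completed, the number in the latest filename is the value to pass as initial_epoch.

```python
import os
import re
import tempfile

# Hypothetical pattern: matches "cp-0003.ckpt" (a saved-model directory) and
# the "cp-0000.ckpt.index" / "cp-0000.ckpt.data-..." files from save_weights().
CKPT_RE = re.compile(r"^cp-(\d+)\.ckpt(?:\.|$)")


def latest_checkpoint(checkpoint_dir):
    """Return (epoch, path) for the highest-numbered checkpoint, or None."""
    epochs = set()
    for name in os.listdir(checkpoint_dir):
        match = CKPT_RE.match(name)
        if match:
            epochs.add(int(match.group(1)))
    if not epochs:
        return None
    epoch = max(epochs)
    return epoch, os.path.join(checkpoint_dir, f"cp-{epoch:04d}.ckpt")


# Demo on a throwaway directory laid out like the question's weights/ folder.
with tempfile.TemporaryDirectory() as demo_dir:
    for name in ["cp-0000.ckpt.index", "cp-0000.ckpt.data-00000-of-00001"]:
        open(os.path.join(demo_dir, name), "w").close()
    for name in ["cp-0001.ckpt", "cp-0002.ckpt", "cp-0003.ckpt"]:
        os.mkdir(os.path.join(demo_dir, name))
    epoch, path = latest_checkpoint(demo_dir)
    print(epoch, os.path.basename(path))  # 3 cp-0003.ckpt
```

With a real weights/ folder, the returned path would go to load_model(...) and the epoch to fit(..., initial_epoch=epoch). If the only match is the weights-only cp-0000.ckpt written by save_weights(), it would need model.load_weights(...) rather than load_model(...).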