Tags: machine-learning, neural-network, keras, deep-learning

Does model.save() save the model of the last epoch or the best epoch?


This one-liner saves a Keras deep learning model:

model.save('my_model.h5')

Does model.save() save the model from the last epoch or from the best epoch? Sometimes the last epoch does not improve performance.


Solution

  • It saves the model in its exact current state. If you call it after Model.fit returns, it saves the weights from the last epoch, whether or not that epoch was the best one.

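    To keep the best epoch instead (where "best" means the lowest validation loss or the highest validation accuracy), use the ModelCheckpoint callback with save_best_only=True. It overwrites the checkpoint file only when the monitored metric improves. It monitors val_loss by default, so fit() needs validation data: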

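    from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard, TerminateOnNaN

    epochs = 100
    # other parameters...
    # x, y: training data; valid: a (x_val, y_val) validation tuple

    model.fit(x, y,
              epochs=epochs,
              validation_data=valid,
              verbose=2,
              callbacks=[
                  TerminateOnNaN(),
                  TensorBoard('./logs'),
                  # Rewrite the file only when val_loss reaches a new minimum.
                  # Weights-only with a .weights.h5 name works in both
                  # tf.keras 2.x and Keras 3.
                  ModelCheckpoint('best.weights.h5',
                                  monitor='val_loss',
                                  mode='min',
                                  save_best_only=True,
                                  save_weights_only=True),
                  # ... other callbacks
              ])
    # After fit() returns, the model holds the weights from the last (100th) epoch.
    model.load_weights('best.weights.h5')  # restore the weights with the lowest val_loss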
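Conceptually, save_best_only is simple bookkeeping: after each epoch the callback compares the monitored value with the best value seen so far and snapshots the weights only on improvement, while training keeps changing the live model. The pure-Python sketch here models that logic under stated assumptions: `BestCheckpoint` and `fake_fit` are hypothetical names, the "weights" are a dict standing in for real tensors, and the val_loss values are made up. It is not Keras's implementation, which writes to disk and also supports mode='max' for metrics such as accuracy.

```python
import copy


class BestCheckpoint:
    """Hypothetical, simplified stand-in for
    ModelCheckpoint(monitor='val_loss', mode='min', save_best_only=True).
    Keeps an in-memory copy of the weights instead of writing a file."""

    def __init__(self):
        self.best_value = float("inf")  # lowest val_loss seen so far
        self.best_epoch = None
        self.best_weights = None

    def on_epoch_end(self, epoch, weights, val_loss):
        # Snapshot only when the monitored value strictly improves.
        if val_loss < self.best_value:
            self.best_value = val_loss
            self.best_epoch = epoch
            # Copy, otherwise the snapshot would alias the live weights.
            self.best_weights = copy.deepcopy(weights)


def fake_fit(val_losses, checkpoint):
    """Hypothetical training loop. 'weights' only records which epoch
    last updated it; val_losses plays the role of per-epoch val_loss."""
    weights = {"updated_at_epoch": None}
    for epoch, val_loss in enumerate(val_losses, start=1):
        weights["updated_at_epoch"] = epoch  # each epoch changes the weights
        checkpoint.on_epoch_end(epoch, weights, val_loss)
    return weights  # the state model.save() would see after fit()


val_losses = [0.90, 0.55, 0.40, 0.42, 0.47]  # made-up numbers for illustration
ckpt = BestCheckpoint()
final_weights = fake_fit(val_losses, ckpt)

print("last epoch:", final_weights["updated_at_epoch"])
print("best epoch:", ckpt.best_epoch, "val_loss:", ckpt.best_value)
```

Running it prints last epoch 5 but best epoch 3 (val_loss 0.4): the state returned by the loop, which is what model.save() would write, reflects epoch 5, while the checkpoint holds epoch 3. The deepcopy matters; without it the snapshot would share the live dict and silently track the last epoch.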