I have a fit() function that uses the ModelCheckpoint() callback to save the model if it is better than any previous model, using save_weights_only=False so that the entire model is saved. This should allow me to resume training at a later date with load_model().
Unfortunately, somewhere in the save()/load_model() roundtrip, the metric values are not preserved -- for example, val_loss is set to inf. This means that when training resumes, after the first epoch ModelCheckpoint() will always save the model, which will almost always be worse than the previous champion from the earlier session.
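Roughly, the resume pattern I mean is this (the file name and data variables are just placeholders):

from keras.models import load_model
from keras.callbacks import ModelCheckpoint

# Reload the previous champion model in full...
model = load_model('champion.h5')

# ...but a freshly constructed callback starts over with best = inf
# (for mode='min'), so it knows nothing about the champion's val_loss.
checkpoint = ModelCheckpoint('champion.h5', monitor='val_loss',
                             save_best_only=True, mode='min', verbose=1)

model.fit(x_train, y_train, validation_data=(x_val, y_val),
          epochs=10, callbacks=[checkpoint])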
I have determined that I can set ModelCheckpoint()'s current best value before resuming training, as follows:
myCheckpoint = ModelCheckpoint(...)
myCheckpoint.best = bestValueSoFar
I could, of course, monitor the values I need, write them out to a file, and read them back in when I resume, but as a Keras newbie I am wondering whether I have missed something obvious.
I ended up quickly writing my own callback that keeps track of the best training values so I can reload them. It looks like this:
import json
import os

from keras import callbacks   # or: from tensorflow.keras import callbacks


# State monitor callback. Tracks how well we are doing and writes
# some state to a json file. This lets us resume training seamlessly.
#
# ModelState.state is:
#
# { "epoch_count": nnnn,
#   "best_values": { dictionary with keys for each log value },
#   "best_epoch": { dictionary with keys for each log value }
# }
class ModelState(callbacks.Callback):

    def __init__(self, state_path):
        super().__init__()
        self.state_path = state_path
        if os.path.isfile(state_path):
            print('Loading existing .json state')
            with open(state_path, 'r') as f:
                self.state = json.load(f)
        else:
            self.state = { 'epoch_count': 0,
                           'best_values': {},
                           'best_epoch': {}
                         }

    def on_train_begin(self, logs={}):
        print('Training commences...')

    def on_epoch_end(self, epoch, logs={}):
        # Currently, for everything we track, lower is better
        for k in logs:
            if k not in self.state['best_values'] or logs[k] < self.state['best_values'][k]:
                self.state['best_values'][k] = float(logs[k])
                self.state['best_epoch'][k] = self.state['epoch_count']

        with open(self.state_path, 'w') as f:
            json.dump(self.state, f, indent=4)

        print('Completed epoch', self.state['epoch_count'])
        self.state['epoch_count'] += 1
Then, in the fit() function, something like this:
# Set up the model state, reading in prior results info if available
model_state = ModelState(path_to_state_file)
# Checkpoint the model if we get a better result
model_checkpoint = callbacks.ModelCheckpoint(path_to_model_file,
                                             monitor='val_loss',
                                             save_best_only=True,
                                             verbose=1,
                                             mode='min',
                                             save_weights_only=False)
# If we have trained previously, set up the model checkpoint so it won't save
# until it finds something better. Otherwise, it would always save the results
# of the first epoch.
if 'val_loss' in model_state.state['best_values']:
    model_checkpoint.best = model_state.state['best_values']['val_loss']
callback_list = [model_checkpoint,
                 model_state]
# Offset the epoch count if we are resuming training. If you don't do
# this, only epochs - initial_epoch epochs will be run.
initial_epoch = model_state.state['epoch_count']
epochs += initial_epoch
# .fit() or .fit_generator, etc. goes here.
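For completeness, the call itself is then just the usual fit() with initial_epoch and the callback list (the data arguments are placeholders):

model.fit(x_train, y_train,
          validation_data=(x_val, y_val),
          epochs=epochs,
          initial_epoch=initial_epoch,
          callbacks=callback_list,
          verbose=1)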