Tags: keras, callback, pickle, google-colaboratory, tf.keras

tf.keras: how to save a ModelCheckpoint object


ModelCheckpoint can be used to save the best model based on a specific monitored metric, so it obviously stores information about the best metric value it has seen within the object. If you train on Google Colab, for example, your instance can be killed without warning, and after a long training session you would lose this info.

I tried to pickle the ModelCheckpoint object so that I can reuse the same object when I bring my notebook back, but got:

TypeError: can't pickle _thread.lock objects  

Is there a good way to do this? You can reproduce the error with:

import pickle

import tensorflow as tf

chkpt_cb = tf.keras.callbacks.ModelCheckpoint('model.{epoch:02d}-{val_loss:.4f}.h5',
                                              monitor='val_loss',
                                              verbose=1,
                                              save_best_only=True)

# Even with the file correctly opened in binary mode ('wb'), the dump
# fails because the callback's object graph contains thread locks.
with open('chkpt_cb.pickle', 'wb') as f:
  pickle.dump(chkpt_cb, f, protocol=pickle.HIGHEST_PROTOCOL)
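
The failure is not specific to ModelCheckpoint: any object whose graph reaches a `_thread.lock` is unpicklable, and this callback's object graph evidently reaches one (e.g. via TF internals). A minimal sketch of the underlying limitation:

import pickle
import threading

# Locks cannot be serialized; pickling one raises the same TypeError.
try:
  pickle.dumps(threading.Lock())
except TypeError as e:
  print(e)  # e.g. "cannot pickle '_thread.lock' object"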

Solution

  • Since the callback object itself can't be pickled (due to the thread lock issue, and pickling it is not advisable anyway), you can pickle this attribute instead:

    best = chkpt_cb.best
    

    This stores the best monitored metric value the callback has seen. It is a plain float, so you can pickle it, reload it next time, and then do this:

    chkpt_cb.best = best   # restore onto the brand-new callback you create after Colab kills your session
    

    This is my own setup:

    import os
    import pickle

    import tensorflow as tf

    # All paths should be on Google Drive; omitted here for simplicity.

    chkpt_cb = tf.keras.callbacks.ModelCheckpoint(filepath='model.{epoch:02d}-{val_loss:.4f}.h5',
                                                  monitor='val_loss',
                                                  verbose=1,
                                                  save_best_only=True)

    # Restore the best metric seen so far, if a previous session saved one.
    if os.path.exists('chkpt_cb.best.pickle'):
      with open('chkpt_cb.best.pickle', 'rb') as f:
        chkpt_cb.best = pickle.load(f)

    # Persist the best metric at the end of every epoch.
    def save_chkpt_cb():
      with open('chkpt_cb.best.pickle', 'wb') as f:
        pickle.dump(chkpt_cb.best, f, protocol=pickle.HIGHEST_PROTOCOL)

    save_chkpt_cb_callback = tf.keras.callbacks.LambdaCallback(
        on_epoch_end=lambda epoch, logs: save_chkpt_cb()
    )

    history = model.fit_generator(generator=train_data_gen,
                                  validation_data=dev_data_gen,
                                  epochs=5,
                                  callbacks=[chkpt_cb, save_chkpt_cb_callback])
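
    Note that Keras invokes callbacks in list order, so placing save_chkpt_cb_callback after chkpt_cb guarantees the pickled value already reflects the current epoch's result.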
    

    So even when your Colab session gets killed, you can still retrieve the last best metric, inform the new instance about it, and continue training as usual. This helps especially when you re-compile with a stateful optimizer, which may cause a temporary regression in the loss/metric: with the restored best value, ModelCheckpoint won't save those worse models during the first few epochs.
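
    To actually continue training after a restart, you would typically also reload the last saved model from Drive. A minimal sketch, assuming the model.{epoch:02d}-{val_loss:.4f}.h5 naming from the setup above (selecting the newest file by modification time is just one reasonable convention):

    import glob
    import os

    import tensorflow as tf

    # Pick the most recently written checkpoint, if any exist yet.
    checkpoints = sorted(glob.glob('model.*.h5'), key=os.path.getmtime)
    if checkpoints:
      model = tf.keras.models.load_model(checkpoints[-1])

    Since ModelCheckpoint saved the full model here (save_weights_only defaults to False), load_model also restores the optimizer state from the HDF5 file; passing initial_epoch to fit keeps the {epoch:02d} numbering in the filenames consistent across sessions.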