python tensorflow keras checkpoint

tf.keras.callbacks.ModelCheckpoint vs tf.train.Checkpoint


I am fairly new to the TensorFlow world but have written some programs in Keras. Since TensorFlow 2 adopted Keras as its official high-level API, I am quite confused about the difference between tf.keras.callbacks.ModelCheckpoint and tf.train.Checkpoint. If anybody can shed light on this, I would appreciate it.


Solution

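  • TensorFlow is a general 'computation' library, and Keras is a deep learning library that can run on top of TF or other backends such as PyTorch. So what TF provides is the more generic version, not tailored to a Keras training loop. If you compare the docs you can see how much more comprehensive and specialized ModelCheckpoint is. tf.train.Checkpoint just saves and restores the values of the objects you hand it (variables, optimizers, models), and you decide when to call it. ModelCheckpoint is much smarter: it can save every epoch, or only when a monitored metric improves!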

    Also, ModelCheckpoint is a callback. That means you can just create an instance of it and pass it to model.fit, which then invokes it during training:

    from tensorflow.keras.callbacks import ModelCheckpoint

    model_checkpoint = ModelCheckpoint(...)
    model.fit(..., callbacks=[..., model_checkpoint, ...], ...)
    

    I took a quick look at Keras's implementation of ModelCheckpoint: it calls either the save or the save_weights method on the Model, which in some cases uses TensorFlow's Checkpoint itself. So it is not a wrapper per se, but it certainly sits at a higher level of abstraction, more specialized for saving Keras models.
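
For contrast with the callback above, here is a minimal sketch of using tf.train.Checkpoint by hand. The variable names (step, weights) and the temporary directory are assumptions made up for illustration, not anything from the question. It only shows the manual save/restore round trip that ModelCheckpoint would otherwise drive for you from inside fit.

```python
import os
import tempfile

import tensorflow as tf

# Hypothetical state to track: a step counter and a small weight vector.
step = tf.Variable(0, dtype=tf.int64)
weights = tf.Variable([1.0, 2.0, 3.0])

# tf.train.Checkpoint tracks whatever objects are passed as keyword arguments.
ckpt = tf.train.Checkpoint(step=step, weights=weights)

# save() appends a running counter to the prefix and returns the full path.
ckpt_dir = tempfile.mkdtemp()
save_path = ckpt.save(os.path.join(ckpt_dir, "ckpt"))

# Overwrite the in-memory values, as if training had moved on.
step.assign(100)
weights.assign([0.0, 0.0, 0.0])

# restore() loads the values stored on disk back into the same variables.
ckpt.restore(save_path)
restored_step = int(step.numpy())
restored_weights = weights.numpy().tolist()
```

Nothing here decides when to save or which metric to watch. With tf.train.Checkpoint that logic is yours to write, which is the part ModelCheckpoint handles for a Keras model.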