
How can I use tf.keras.callbacks.ModelCheckpoint in Keras Tuner?


So I want to use tf.keras.callbacks.ModelCheckpoint in Keras Tuner, but the way you choose the checkpoint path doesn't let you save the file under a specific name associated with the trial and execution that produced the checkpoint, only with an epoch.

That is, if I simply pass this callback to the Keras Tuner, then when the checkpoints are saved I won't know which trial and trial execution each checkpoint belongs to, only which epoch.
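For context, ModelCheckpoint fills the placeholders in `filepath` from the epoch number and the metric values in `logs`, which is why a file name can carry the epoch (or a metric such as `val_loss`) but nothing about a tuner trial. A minimal sketch on a throwaway model, assuming random stand-in data (hypothetical) and the `.weights.h5` suffix that Keras 3 requires when `save_weights_only=True`:

```python
import glob
import os
import tempfile

import numpy as np
import tensorflow as tf

# Hypothetical random data, only to have something to fit
x = np.random.rand(64, 4).astype("float32")
y = np.random.rand(64, 1).astype("float32")

model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(1),
])
model.compile(optimizer="adam", loss="mse")

ckpt_dir = tempfile.mkdtemp()
# {epoch} and {val_loss} are filled in by ModelCheckpoint at the end of each
# epoch; there is no placeholder for a Keras Tuner trial or execution.
cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=os.path.join(ckpt_dir, "weights.{epoch:02d}-{val_loss:.2f}.weights.h5"),
    save_weights_only=True,
)
model.fit(x, y, epochs=2, validation_split=0.25, callbacks=[cp_callback], verbose=0)

saved = sorted(os.path.basename(p) for p in glob.glob(os.path.join(ckpt_dir, "*.weights.h5")))
print(saved)
```

The printed names look like `weights.01-<val_loss>.weights.h5` and `weights.02-<val_loss>.weights.h5`, one per epoch.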


Solution

  • You can use tf.keras.callbacks.ModelCheckpoint with Keras Tuner the same way you would with any other model to save checkpoints.

    After retraining the model with the best hyperparameters found by the search (as in this model), you can define a checkpoint callback and save the weights as below:

    import os
    import tensorflow as tf

    # Assumes tuner, best_hps, best_epoch and the img_*/label_* arrays
    # from the Keras Tuner tutorial this answer follows.
    hypermodel = tuner.hypermodel.build(best_hps)

    # Retrain the model with the best hyperparameters
    hypermodel.fit(img_train, label_train, epochs=best_epoch, validation_split=0.2)

    # Keras 3 requires weight-only checkpoints to end in ".weights.h5"
    checkpoint_path = "training_1/cp.weights.h5"
    checkpoint_dir = os.path.dirname(checkpoint_path)

    # Create a callback that saves the model's weights after every epoch
    cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
                                                     save_weights_only=True,
                                                     verbose=1)

    # Train for 5 more epochs, checkpointing along the way
    history = hypermodel.fit(img_train, label_train, epochs=5,
                             validation_split=0.2, callbacks=[cp_callback])
    print(os.listdir(checkpoint_dir))

    # Evaluate the trained model
    loss, acc = hypermodel.evaluate(img_test, label_test, verbose=2)
    print("Trained model, accuracy: {:5.2f}%".format(100 * acc))

    # Load the weights from the checkpoint
    hypermodel.load_weights(checkpoint_path)

    # Re-evaluate the model with the restored weights
    loss, acc = hypermodel.evaluate(img_test, label_test, verbose=2)
    print("Restored model, accuracy: {:5.2f}%".format(100 * acc))
    

    Please refer to this link for more information on saving and loading model checkpoints.