
Tensorflow restore `tf.Session` saved checkpoint using `tf.train.MonitoredTrainingSession`


I have code for training a CNN using tf.train.MonitoredTrainingSession.

When I create a new tf.train.MonitoredTrainingSession, I can pass the checkpoint directory as an input parameter and the session will automatically restore the latest saved checkpoint it can find. I can also set up hooks to train until some step. For example, if the checkpoint's step is 150,000 and I would like to train until 200,000, I set last_step to 200,000.

The above process works perfectly as long as the latest checkpoint was saved using a tf.train.MonitoredTrainingSession. However, if I try to restore a checkpoint that was saved using a normal tf.Session, everything breaks: it complains that it cannot find certain keys in the checkpoint.

The training is done with this:

with tf.train.MonitoredTrainingSession(
    checkpoint_dir=FLAGS.retrain_dir,
    hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_training_steps),
           tf.train.NanTensorHook(loss),
           _LoggerHook()],
    config=tf.ConfigProto(
        log_device_placement=FLAGS.log_device_placement)) as mon_sess:
  while not mon_sess.should_stop():
    mon_sess.run(train_op)

If the checkpoint_dir argument points to a folder with no checkpoints, this starts training from scratch. If it contains a checkpoint saved from a previous training session, it restores the latest checkpoint and continues training.

Now, I restore the latest checkpoint, modify some variables, and save them:

saver = tf.train.Saver(variables_to_restore)

ckpt = tf.train.get_checkpoint_state(FLAGS.train_dir)

with tf.Session() as sess:
  if ckpt and ckpt.model_checkpoint_path:
    # Restores from checkpoint
    saver.restore(sess, ckpt.model_checkpoint_path)
    print(ckpt.model_checkpoint_path)
    restored_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
    FLAGS.max_training_steps = int(restored_step) + FLAGS.max_training_steps
  else:
    print('No checkpoint file found')
    return

  prune_convs(sess)
  saver.save(sess, FLAGS.retrain_dir+"model.ckpt-"+restored_step)

As you can see, just before saver.save... I am pruning all of the convolutional layers in the network. There is no need to describe how and why that is done; the point is that the network is in fact modified. Then I save the network to a checkpoint.

Now, if I run the test on the saved, modified network, the test works just fine. However, when I try to run tf.train.MonitoredTrainingSession on the checkpoint that was saved, it says:

Key conv1/weight_loss/avg not found in checkpoint

Also, I have noticed that the checkpoint that was saved with tf.Session is half the size of the checkpoint that was saved with tf.train.MonitoredTrainingSession.
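
Listing the variables stored in each checkpoint and comparing the names makes the gap visible; here is a rough sketch, with placeholder paths:

import tensorflow as tf

# Placeholder checkpoint paths; substitute the files actually written by
# tf.train.MonitoredTrainingSession and by the plain tf.Session above.
monitored_ckpt = "/tmp/train_dir/model.ckpt-150000"
session_ckpt = "/tmp/retrain_dir/model.ckpt-150000"

monitored_vars = {name for name, _ in tf.train.list_variables(monitored_ckpt)}
session_vars = {name for name, _ in tf.train.list_variables(session_ckpt)}

# Variables present in the original checkpoint but missing from the re-saved
# one (e.g. conv1/weight_loss/avg) show up in this difference.
print(sorted(monitored_vars - session_vars))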

I know I'm doing it wrong; any suggestions on how to make this work?


Solution

  • I figured it out. Apparently, the tf.train.Saver I created does not restore all of the variables from a checkpoint. I tried restoring and saving immediately, and the output was half the size.

    I used tf.train.list_variables to get all of the variables from the latest checkpoint, converted them into tf.Variable objects, and created a dict from them. Then I passed the dict to tf.train.Saver, and it restored all of my variables.

    The next thing was to initialize all of the variables and then modify the weights.

    Now it is working.
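
    Roughly, the end result looked like the sketch below. FLAGS.train_dir, FLAGS.retrain_dir, and prune_convs are taken from the question; tf.train.load_variable is just one way to materialize the stored values so that each tf.Variable gets the right shape and dtype.

    import tensorflow as tf

    ckpt_path = tf.train.latest_checkpoint(FLAGS.train_dir)

    # Build a tf.Variable for every entry stored in the checkpoint, keyed by
    # its checkpoint name, so the Saver covers *all* saved variables
    # (including moving averages such as conv1/weight_loss/avg).
    var_dict = {}
    for name, _ in tf.train.list_variables(ckpt_path):
      value = tf.train.load_variable(ckpt_path, name)
      var_dict[name] = tf.Variable(value, name=name)

    saver = tf.train.Saver(var_dict)

    with tf.Session() as sess:
      # The initial values already come from the checkpoint; restoring with
      # the dict-based Saver is equivalent and ensures every stored variable
      # is covered before re-saving.
      sess.run(tf.global_variables_initializer())
      saver.restore(sess, ckpt_path)
      restored_step = ckpt_path.split('-')[-1]

      prune_convs(sess)  # modify the weights, as in the question

      saver.save(sess, FLAGS.retrain_dir + "model.ckpt-" + restored_step)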