
How to continue training an Inception model from a checkpoint in TensorFlow


I have loaded a pretrained Inception model:

if FLAGS.pretrained_model_checkpoint_path:
    assert tf.gfile.Exists(FLAGS.pretrained_model_checkpoint_path)
    variables_to_restore = tf.get_collection(
        slim.variables.VARIABLES_TO_RESTORE)
    restorer = tf.train.Saver(variables_to_restore)
    restorer.restore(sess, FLAGS.pretrained_model_checkpoint_path)
    print('%s: Pre-trained model restored from %s' %
          (datetime.now(), FLAGS.pretrained_model_checkpoint_path))

Then I trained the model on my data using flowers_train.py.

After training completed, the loss was about 1.0 and the model was saved in the specified directory.
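For reference, checkpoints in the training script are written with the usual tf.train.Saver pattern. A minimal, self-contained sketch of that saving side (the train_dir path and the 'model.ckpt' name here are illustrative assumptions, not taken from my script):

import os
import tensorflow as tf

train_dir = '/tmp/flowers_train'  # hypothetical path; corresponds to TRAIN_DIR
global_step = tf.Variable(0, trainable=False, name='global_step')
# ... the rest of the training graph would be built here ...

saver = tf.train.Saver()  # saves all variables by default
with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())  # tf.global_variables_initializer() in newer TF
    # Inside the training loop the script periodically writes a checkpoint:
    checkpoint_path = os.path.join(train_dir, 'model.ckpt')
    saver.save(sess, checkpoint_path, global_step=global_step)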

Now I want to continue training, so I restore the model:

if FLAGS.checkpoint_dir is not None:
    # restoring from the checkpoint file
    ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
    tf.train.Saver().restore(sess, ckpt.model_checkpoint_path)
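When a restore appears to have no effect, one thing I can check is whether the variable names stored in the checkpoint actually match the variables in the rebuilt graph. A minimal sketch of inspecting the checkpoint (using the same FLAGS.checkpoint_dir as above):

import tensorflow as tf

ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
reader = tf.train.NewCheckpointReader(ckpt.model_checkpoint_path)

# Variables (and shapes) stored in the checkpoint file.
for name, shape in sorted(reader.get_variable_to_shape_map().items()):
    print(name, shape)

# Variables the current graph would try to restore; the names must line up.
for v in tf.all_variables():  # tf.global_variables() in newer TF
    print(v.op.name, v.get_shape().as_list())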

Then I continue training the model, but the loss on the first step is about 6.5, which in effect means the model was not initialised from the checkpoint at all.
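My understanding of the required ordering (and what the stock inception_train.py appears to do for the pretrained checkpoint) is: build the same graph, run the variable initializer, and only then restore, so that the checkpoint values overwrite the freshly initialized ones rather than the other way around. A minimal sketch of that ordering:

# The same graph as in the first training run must be built before this point.
saver = tf.train.Saver()  # saves/restores all variables by default

with tf.Session() as sess:
    # Initialize everything first ...
    sess.run(tf.initialize_all_variables())  # tf.global_variables_initializer() in newer TF

    # ... then overwrite the initialized values with the checkpoint.
    ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
    if ckpt and ckpt.model_checkpoint_path:
        saver.restore(sess, ckpt.model_checkpoint_path)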

Here is the whole content of my inception_train.py, which was modified from the original inception_train.py.

I started the first training run with:

bazel-bin/inception/flowers_train --train_dir="{$TRAIN_DIR}" --data_dir="{$DATA_DIR}" --fine_tune=True --initial_learning_rate=0.001 --input_queue_memory_factor=1 --batch_size=64 --max_steps=100 --pretrained_model_checkpoint_path="/home/tensorflow/inception-v3/model.ckpt-157585"

I tried to continue training with this command:

bazel-bin/inception/flowers_train --train_dir="{$TRAIN_NEW_DIR}" --data_dir="{$DATA_DIR}" --fine_tune=False --initial_learning_rate=0.001 --input_queue_memory_factor=1 --batch_size=64 --max_steps=2000 --checkpoint_dir="{$TRAIN_DIR}"

Can anyone please explain what I am doing wrong when initializing the trained model?


Solution

  • I solved it by using the right arg_scope, as follows:

    with slim.arg_scope(inception_v3.inception_v3_arg_scope()):
        logits, _ = inception_v3.inception_v3(eval_inputs, num_classes=1001,
                                              is_training=False)
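    For context, here is a minimal end-to-end sketch of building the graph under that arg_scope and then restoring a checkpoint. The `from nets import inception_v3` import (TF-Slim nets package in the tensorflow/models repository) and the checkpoint path are assumptions, not part of the original scripts:

        import tensorflow as tf
        import tensorflow.contrib.slim as slim
        from nets import inception_v3  # assumed TF-Slim nets import path

        checkpoint_dir = '/tmp/flowers_train'  # hypothetical path; use your TRAIN_DIR

        # Build the graph under the same arg_scope that was used for training,
        # so batch norm, weight decay, etc. match the checkpointed variables.
        eval_inputs = tf.placeholder(tf.float32, [None, 299, 299, 3])
        with slim.arg_scope(inception_v3.inception_v3_arg_scope()):
            logits, _ = inception_v3.inception_v3(eval_inputs, num_classes=1001,
                                                  is_training=False)

        # Restore the variables created above from the latest checkpoint.
        saver = tf.train.Saver(slim.get_variables_to_restore())
        with tf.Session() as sess:
            ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
            saver.restore(sess, ckpt.model_checkpoint_path)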