For ease of discussion, the models below have been simplified.
Let's say there are around 40,000 512x512 images in my training set. I am trying to implement pre-training, and my plan is the following:
1. Train a neural network (let's call it net_1) that takes 256x256 images as input, and save the trained model in TensorFlow checkpoint format.
net_1: input -> 3 conv2d -> maxpool2d -> 3 conv2d -> rmspool -> flatten -> dense
Let's call the convolutional part of this structure net_1_kernel:
net_1_kernel: 3 conv2d -> maxpool2d -> 3 conv2d
and call the remaining part other_layers:
other_layers: rmspool -> flatten -> dense
Thus we can represent net_1 in the following form:
net_1: input -> net_1_kernel -> other_layers
2. Insert several layers into the structure of net_1 and call the result net_2. It should look like this:
net_2: input -> net_1_kernel -> maxpool2d -> 3 conv2d -> other_layers
net_2 will take 512x512 images as input.
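A quick way to sanity-check the two layouts is to track the spatial size through each network. The sketch below is only illustrative: it assumes every conv2d uses "same" padding with stride 1 and every maxpool2d halves the width and height, neither of which is stated above, and the helper name spatial_size is made up for this example.

```python
# Assumptions (not stated in the question): conv2d keeps the spatial size
# ("same" padding, stride 1) and maxpool2d halves it (2x2 window, stride 2).

def spatial_size(input_size, layers):
    """Return the feature-map width/height after applying the listed layers."""
    size = input_size
    for layer in layers:
        if layer == "maxpool2d":
            size //= 2  # pooling halves the spatial size
        # conv2d leaves the size unchanged under the assumptions above
    return size


NET_1_KERNEL = ["conv2d"] * 3 + ["maxpool2d"] + ["conv2d"] * 3
NET_1_BODY = NET_1_KERNEL                                   # layers before other_layers in net_1
NET_2_BODY = NET_1_KERNEL + ["maxpool2d"] + ["conv2d"] * 3  # layers before other_layers in net_2

size_net_1 = spatial_size(256, NET_1_BODY)
size_net_2 = spatial_size(512, NET_2_BODY)
print(size_net_1, size_net_2)
```

Under these assumptions, other_layers receives feature maps of the same spatial size in both networks, because the extra maxpool2d in net_2 compensates for the larger input.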
When I train net_2, I would like to use the saved weights and biases in the checkpoint file of net_1 to initialize the net_1_kernel part in net_2. How can I do this?
I know that I can load a checkpoint to make predictions on test data. But in that case it loads everything (net_1_kernel and other_layers). What I want is to load net_1_kernel only and use it to initialize the weights and biases in net_2.
I also know that I can print the contents of a checkpoint file to a text file and copy and paste the values to initialize the weights and biases manually. However, there are so many numbers in those weights and biases that this would be my last resort.
First of all, you can use the following code to list all the tensors stored in the checkpoint file you saved.
from tensorflow.python.tools import inspect_checkpoint as chkp
# Lists the tensor names only; tensor_name is not used when all_tensor_names=True
chkp.print_tensors_in_checkpoint_file(file_name="file.ckpt", tensor_name="", all_tensors=False, all_tensor_names=True)
Remember that a default tf.train.Saver() handles every saveable variable in the graph, so restoring with it loads everything (net_1_kernel and other_layers). If you want to save and restore only specific variables, pass them to the Saver explicitly:
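import tensorflow as tf
# Collect only the variables whose names contain the net_1_kernel scope
var = [v for v in tf.trainable_variables() if "net_1_kernel" in v.name]
# A Saver built from an explicit list saves and restores only those variables
saverAndRestore = tf.train.Saver(var)
saverAndRestore.save(sess_1, "net_1.ckpt")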
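This saves only the variables in the list var to net_1.ckpt.
saverAndRestore.restore(sess_1, "net_1.ckpt")
This restores only the variables in the list var from net_1.ckpt. To initialize net_2, build the same kind of Saver over the net_1_kernel variables in net_2's graph, run the variable initializer first, and then call restore in net_2's session so that the restored values are not overwritten.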
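Apart from the above, all you have to do is name or scope your variables (for example, under a variable scope called net_1_kernel) so that the name filter above selects exactly the layers you want. Use the same names in net_2, because the Saver matches variables to checkpoint entries by name.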
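Putting the pieces together, here is a minimal end-to-end sketch of the save-then-partially-restore flow. It is an illustration under assumptions, not the exact setup from the question: the one-layer net_1_kernel function, the tensor shapes, the extra_layers scope, and the temporary checkpoint path are all hypothetical stand-ins, and it uses tensorflow.compat.v1 so that the TF1-style Saver and Session calls also run on a TensorFlow 2.x installation.

```python
import os
import tempfile

import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()  # use TF1-style graphs and sessions


def net_1_kernel(x):
    # Hypothetical one-layer stand-in for the real net_1_kernel.
    # The scope name is what the name filter below relies on.
    with tf.variable_scope("net_1_kernel"):
        w = tf.get_variable("w", shape=[3, 3, 1, 4])
        b = tf.get_variable("b", shape=[4], initializer=tf.zeros_initializer())
        conv = tf.nn.conv2d(x, w, strides=[1, 1, 1, 1], padding="SAME")
        return tf.nn.relu(conv + b)


def kernel_variables():
    # Same name-based filter as above, applied to the current default graph.
    return [v for v in tf.trainable_variables() if "net_1_kernel" in v.name]


ckpt_path = os.path.join(tempfile.mkdtemp(), "net_1.ckpt")  # hypothetical location

# net_1: build, initialize (training omitted), save only the kernel variables.
graph_1 = tf.Graph()
with graph_1.as_default():
    x_1 = tf.placeholder(tf.float32, [None, 8, 8, 1])
    h_1 = net_1_kernel(x_1)
    with tf.variable_scope("other_layers"):
        tf.get_variable("w", shape=[4, 2])  # not saved: outside the filter
    saver = tf.train.Saver(kernel_variables())
    with tf.Session() as sess_1:
        sess_1.run(tf.global_variables_initializer())
        saved = {v.op.name: sess_1.run(v) for v in kernel_variables()}
        saver.save(sess_1, ckpt_path)

# Names actually stored in the checkpoint.
names_in_ckpt = sorted(name for name, _ in tf.train.list_variables(ckpt_path))

# net_2: larger input, extra layers, same scope name for the shared part.
graph_2 = tf.Graph()
with graph_2.as_default():
    x_2 = tf.placeholder(tf.float32, [None, 16, 16, 1])
    h_2 = net_1_kernel(x_2)
    with tf.variable_scope("extra_layers"):
        tf.get_variable("w", shape=[3, 3, 4, 4])  # keeps its fresh initialization
    restorer = tf.train.Saver(kernel_variables())
    with tf.Session() as sess_2:
        # Initialize everything first, then overwrite the kernel part.
        sess_2.run(tf.global_variables_initializer())
        restorer.restore(sess_2, ckpt_path)
        restored = {v.op.name: sess_2.run(v) for v in kernel_variables()}

weights_match = all(np.array_equal(saved[k], restored[k]) for k in saved)
print(names_in_ckpt)
print(weights_match)
```

The order in net_2 matters: running the initializer after restore would replace the loaded weights with fresh random values. The variables under extra_layers are not in the Saver's list, so they simply keep their initial values and are trained from scratch.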