I have a TensorFlow model that was trained using the Estimator API. Now I want to load this model, add a new layer to it, and train only the new layer (i.e., freeze all other parameters).
What would be the best way to do that?
I managed to load the model from a checkpoint and add the new layer, but it's not clear to me how to feed examples as input and run training. Specifically, I couldn't find a placeholder for the input.
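I have found a simple way to do this and am posting it here in case someone finds it useful: load the trained model from the checkpoint, rebuild the model under a new scope with a placeholder for the input, and copy the trained parameters into the new graph with assign ops for every variable.
Code: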
import tensorflow as tf  # TensorFlow 1.x graph-mode API

sess = tf.Session()

# Load the trained model from checkpoint.
new_saver = tf.train.import_meta_graph('{}.meta'.format(config.ckpt_fullpath))
new_saver.restore(sess, config.ckpt_fullpath)

# Create a new graph with a placeholder for input.
new_model_scope = 'new_scope'
trained_model_scope = 'old_scope'  # take this from the original model function of the estimator.
# variable_scope (unlike name_scope) also prefixes variables created with tf.get_variable / tf.layers.
with tf.variable_scope(new_model_scope):
    model = Model(config)
    input_tensor = tf.placeholder(tf.float32,
                                  [None, config.img_size[0], config.img_size[1], 3])
    model.build_model(input_tensor)

# Initialize the new graph variables with the trained parameters.
# The trailing '/' avoids matching other scopes that merely share the prefix (e.g. 'old_scope_1').
trained_params = [t for t in tf.trainable_variables()
                  if t.name.startswith(trained_model_scope + '/')]
trained_params = sorted(trained_params, key=lambda v: v.name)
new_params = [t for t in tf.trainable_variables()
              if t.name.startswith(new_model_scope + '/')]
new_params = sorted(new_params, key=lambda v: v.name)
# Pairing by sorted name only works if both scopes contain the same variables.
assert len(trained_params) == len(new_params), 'trained and new variables do not line up'

update_ops = [new_v.assign(trained_v)
              for trained_v, new_v in zip(trained_params, new_params)]
sess.run(update_ops)
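The code above stops after copying the weights. The remaining part of the question (training only a new layer while everything else stays frozen, and feeding data through the placeholder) can be sketched as follows, assuming the TF1 graph API through tensorflow.compat.v1. The core idea is to pass only the new layer's variables as var_list to the optimizer's minimize(), so no other variable receives updates. The one-layer "backbone" here is a hypothetical stand-in for model.build_model(input_tensor), and the scope name new_layer, the shapes, and the loss are illustrative assumptions.

```python
import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()  # use TF1-style graphs and sessions

# Hypothetical stand-in for the copied model: a single layer under 'new_scope'.
input_tensor = tf.placeholder(tf.float32, [None, 4], name='input')
with tf.variable_scope('new_scope'):
    backbone_w = tf.get_variable('w', shape=[4, 3])
    features = tf.nn.relu(tf.matmul(input_tensor, backbone_w))

# The new layer to be trained, kept in its own scope so its variables are easy to select.
labels = tf.placeholder(tf.float32, [None, 1], name='labels')
with tf.variable_scope('new_layer'):
    head_w = tf.get_variable('w', shape=[3, 1])
    head_b = tf.get_variable('b', shape=[1], initializer=tf.zeros_initializer())
    logits = tf.matmul(features, head_w) + head_b

loss = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits))

# Freezing: only variables under 'new_layer' are handed to the optimizer.
head_vars = tf.trainable_variables(scope='new_layer')
train_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss, var_list=head_vars)

sess = tf.Session()
sess.run(tf.global_variables_initializer())
backbone_before = sess.run(backbone_w)
head_before = sess.run(head_vars)

# Feed examples through the placeholders with feed_dict.
rng = np.random.RandomState(0)
x = rng.rand(8, 4).astype(np.float32)
y = (rng.rand(8, 1) > 0.5).astype(np.float32)
for _ in range(10):
    sess.run(train_op, feed_dict={input_tensor: x, labels: y})

backbone_after = sess.run(backbone_w)
head_after = sess.run(head_vars)
```

In the real setting, avoid running tf.global_variables_initializer() after the assign ops, since it would overwrite the copied weights; initializing only the new layer with tf.variables_initializer(head_vars) is one way around that.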