
Adding a layer to a trained TensorFlow Estimator


I have a TensorFlow model that was trained using the Estimator API. Now I want to load this model, add a new layer to it, and train only the new layer (i.e. freeze all other parameters).

What would be the best way to do that?

I managed to load the model from a checkpoint and add the new layer, but it's not clear to me how to feed examples as input and perform training. Specifically, I couldn't locate a placeholder for the input.


Solution

  • I found a simple way to do this and am posting it here in case someone finds it useful:

    1. Load the Estimator checkpoint.
    2. Create a placeholder and build a copy of the model graph on top of it, under a new scope.
    3. Collect all trainable variables under the two scopes.
    4. Create assign ops that copy each trained variable into its counterpart in the new copy.

    Code:

        import tensorflow as tf  # TF 1.x graph API

        # `config` and `Model` come from the original training code.
        sess = tf.Session()

        # Load the trained model from the checkpoint.
        new_saver = tf.train.import_meta_graph('{}.meta'.format(config.ckpt_fullpath))
        new_saver.restore(sess, config.ckpt_fullpath)

        # Rebuild the model under a new scope, this time fed by a placeholder.
        new_model_scope = 'new_scope'
        trained_model_scope = 'old_scope'  # take this from the original model function of the estimator
        # Use variable_scope rather than name_scope: name_scope does not prefix
        # variables created with tf.get_variable (which tf.layers uses).
        with tf.variable_scope(new_model_scope):
            model = Model(config)
            input_tensor = tf.placeholder(tf.float32,
                                          [None, config.img_size[0], config.img_size[1], 3])
            model.build_model(input_tensor)

        # Pair old and new trainable variables by sorted name. This only lines up
        # if both copies have the same variable names after the scope prefix.
        trained_params = sorted([v for v in tf.trainable_variables()
                                 if v.name.startswith(trained_model_scope + '/')],
                                key=lambda v: v.name)
        new_params = sorted([v for v in tf.trainable_variables()
                             if v.name.startswith(new_model_scope + '/')],
                            key=lambda v: v.name)
        assert len(trained_params) == len(new_params), 'model copies do not match'

        # Initialize the new graph's variables with the trained parameters.
        update_ops = [new_v.assign(trained_v)
                      for trained_v, new_v in zip(trained_params, new_params)]
        sess.run(update_ops)
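The code above copies the trained weights into a model copy fed by a placeholder; the other half of the question, training only the new layer, can be handled by passing just the new layer's variables as `var_list` to the optimizer's `minimize`, so updates are never applied to the restored weights. Below is a minimal, self-contained sketch of that idea using the TF1 graph API through `tf.compat.v1`. It is not the original poster's code: the `old_scope` variables here are randomly initialized stand-ins for the restored parameters, and the function name `train_new_layer_only`, the `new_layer` scope, the layer sizes, and the learning rate are all hypothetical.

```python
import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1  # TF1-style graph API (also available in TF 2.x)


def train_new_layer_only(steps=50, seed=0):
    """Train a new layer on top of frozen 'old_scope' variables.

    Hypothetical example: the 'old_scope' weights are random stand-ins for
    parameters that would normally be restored from the Estimator checkpoint.
    """
    rng = np.random.RandomState(seed)
    x_data = rng.randn(64, 4).astype(np.float32)
    y_data = rng.randn(64, 1).astype(np.float32)

    graph = tf.Graph()
    with graph.as_default():
        tf1.set_random_seed(seed)
        inputs = tf1.placeholder(tf.float32, [None, 4], name='input')
        labels = tf1.placeholder(tf.float32, [None, 1], name='labels')

        # Stand-in for the trained (to be frozen) part of the model.
        with tf1.variable_scope('old_scope'):
            w_old = tf1.get_variable('w', [4, 8])
            features = tf.nn.relu(tf.matmul(inputs, w_old))

        # The new layer that should be the only thing trained.
        with tf1.variable_scope('new_layer'):
            w_new = tf1.get_variable('w', [8, 1])
            b_new = tf1.get_variable('b', [1], initializer=tf1.zeros_initializer())
            predictions = tf.matmul(features, w_new) + b_new

        loss = tf.reduce_mean(tf.square(predictions - labels))

        # Restrict the optimizer to the new layer's variables only.
        new_vars = tf1.trainable_variables(scope='new_layer')
        train_op = tf1.train.GradientDescentOptimizer(0.05).minimize(
            loss, var_list=new_vars)

        with tf1.Session(graph=graph) as sess:
            sess.run(tf1.global_variables_initializer())
            feed = {inputs: x_data, labels: y_data}
            old_before, loss_before = sess.run([w_old, loss], feed)
            for _ in range(steps):
                sess.run(train_op, feed)
            old_after, loss_after = sess.run([w_old, loss], feed)

    return old_before, old_after, loss_before, loss_after
```

Calling `train_new_layer_only()` returns the stand-in trained weights before and after training (they stay identical) along with the loss before and after (it goes down). In the answer's setup, the same `var_list` filter would be applied to whatever scope the new layer is built under, on top of the `new_scope` model copy. Wrapping the frozen features in `tf.stop_gradient` is an alternative that blocks gradients in the graph itself instead of in the optimizer.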