
Transfer learning / retraining with TensorFlow Estimators


I have been unable to figure out how to do transfer learning / last-layer retraining with the new TF Estimator API.

The Estimator requires a model_fn, which contains the network architecture and the training and eval ops, as defined in the documentation. An example of a model_fn using a CNN architecture is here.

If I want to retrain only the last layer of, for example, the Inception architecture, I'm not sure whether I need to specify the whole model in this model_fn and then load the pre-trained weights, or whether there is a way to use the saved graph, as is done in the 'traditional' approach (example here).

This has been raised as an issue, but it is still open, and the answers there are unclear to me.


Solution

  • It is possible to import the metagraph during model definition and use a SessionRunHook to load the weights from a checkpoint (.ckpt) file.

    def model_fn(features, labels, mode, params):
        # Create the graph here (e.g. import the pre-trained metagraph and
        # add the new last layer; see the sketch below).

        return tf.estimator.EstimatorSpec(
            mode=mode,
            predictions=predictions,
            loss=loss,
            train_op=train_op,
            training_hooks=[RestoreHook()])
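
    The "create the graph here" part can import the pre-trained metagraph directly. Below is a minimal sketch of that body; params['meta_path'], features['image'], the tensor names 'input:0' and 'pool_3:0', and the 'new_layer' scope are placeholders to substitute with the names from your own exported model:

    import tensorflow as tf

    # Inside model_fn: import the saved graph, wiring our input pipeline
    # into the pre-trained network's input tensor.
    tf.train.import_meta_graph(
        params['meta_path'],                       # e.g. '.../model.ckpt.meta'
        input_map={'input:0': features['image']})

    graph = tf.get_default_graph()
    # Bottleneck output of the pre-trained network (tensor name is an
    # assumption; inspect your graph for the real one).
    bottleneck = graph.get_tensor_by_name('pool_3:0')
    # Freeze everything below the new layer.
    bottleneck = tf.stop_gradient(bottleneck)

    # New trainable last layer, under its own scope so the RestoreHook can
    # exclude it when restoring the pre-trained weights.
    with tf.variable_scope('new_layer'):
        logits = tf.layers.dense(tf.layers.flatten(bottleneck),
                                 params['n_classes'])
    # ...then build predictions/loss/train_op from logits and return the
    # EstimatorSpec as above.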
    

    The SessionRunHook can be:

    class RestoreHook(tf.train.SessionRunHook):

        def after_create_session(self, session, coord):
            # Restore the pre-trained weights only on a fresh run (global
            # step 0); on a restart, the Estimator's own checkpoint already
            # holds the trained weights.
            if session.run(tf.train.get_or_create_global_step()) == 0:
                # load weights here
                pass
    

    This way, the weights are loaded at the first step only and are then saved into the model's own checkpoints as training proceeds.
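
    For completeness, here is a minimal sketch of what "load weights here" might look like, using a tf.train.Saver that restores every variable except those of the new last layer. The checkpoint path and the 'new_layer' scope name are assumptions carried over from the sketch above:

    import tensorflow as tf

    class RestoreHook(tf.train.SessionRunHook):
        """Restores pre-trained weights once, before the first training step."""

        def __init__(self, ckpt_path):
            self._ckpt_path = ckpt_path  # e.g. '.../model.ckpt' (placeholder)

        def begin(self):
            # Build the Saver while the graph is still modifiable, covering
            # everything except the freshly initialized last layer.
            to_restore = [v for v in tf.global_variables()
                          if not v.op.name.startswith('new_layer')]
            self._saver = tf.train.Saver(var_list=to_restore)
            self._global_step = tf.train.get_or_create_global_step()

        def after_create_session(self, session, coord):
            # Restore only on a fresh run; when resuming, the Estimator has
            # already loaded its own checkpoint with the trained weights.
            if session.run(self._global_step) == 0:
                self._saver.restore(session, self._ckpt_path)

    With this hook passed as training_hooks=[RestoreHook(params['ckpt_path'])], only the new layer's variables start from their random initialization; everything else starts from the pre-trained checkpoint.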