I have been unable to figure out how to use transfer learning/last layer retraining with the new TF Estimator API.
The Estimator requires a model_fn, which contains the architecture of the network and the training and eval ops, as defined in the documentation. An example of a model_fn using a CNN architecture is here.
If I want to retrain the last layer of, for example, the Inception architecture, I'm not sure whether I need to specify the whole model in this model_fn and then load the pre-trained weights, or whether there is a way to use the saved graph as is done in the 'traditional' approach (example here).
This has been brought up as an issue, but it is still open, and the answers there are unclear to me.
It is possible to load the metagraph during model definition and use a SessionRunHook to load the weights from a ckpt file.
import tensorflow as tf

def model_fn(features, labels, mode, params):
    # Build the graph here, and create an init_fn that knows how to
    # restore the pretrained weights (see the sketch below)
    return tf.estimator.EstimatorSpec(
        mode,
        predictions=predictions,
        loss=loss,
        train_op=train_op,
        training_hooks=[RestoreHook(init_fn)])
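Expanding the skeleton above, here is a minimal sketch of what such a model_fn could look like for last-layer retraining. It assumes the pretrained network was exported as a metagraph whose input and bottleneck tensors are named input:0 and bottleneck:0 (both names, like the params keys, are assumptions), and it omits the eval/predict branches a full model_fn would need:

import tensorflow as tf

def model_fn(features, labels, mode, params):
    # Import the pretrained graph, rewiring its input to our features.
    # 'input:0' and 'bottleneck:0' are hypothetical tensor names that
    # depend on how the original model was exported.
    saver = tf.train.import_meta_graph(
        params['meta_graph_path'],
        input_map={'input:0': features['image']})
    bottleneck = tf.get_default_graph().get_tensor_by_name('bottleneck:0')

    # New last layer, trained from scratch.
    logits = tf.layers.dense(bottleneck, params['num_classes'],
                             name='new_logits')
    predictions = tf.argmax(logits, axis=1)
    loss = tf.losses.sparse_softmax_cross_entropy(labels=labels,
                                                  logits=logits)

    # Optimize only the new layer's variables (last-layer retraining).
    new_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                 scope='new_logits')
    train_op = tf.train.GradientDescentOptimizer(0.01).minimize(
        loss, global_step=tf.train.get_or_create_global_step(),
        var_list=new_vars)

    # import_meta_graph returned a Saver covering only the imported
    # (pretrained) variables; init_fn must be built here, because the
    # graph is finalized by the time the hook runs.
    init_fn = lambda session: saver.restore(session,
                                            params['checkpoint_path'])

    return tf.estimator.EstimatorSpec(
        mode, predictions=predictions, loss=loss, train_op=train_op,
        training_hooks=[RestoreHook(init_fn)])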
The SessionRunHook can be:
class RestoreHook(tf.train.SessionRunHook):
    def __init__(self, init_fn):
        self.init_fn = init_fn

    def after_create_session(self, session, coord=None):
        # Only load the pretrained weights when starting from scratch;
        # on later runs the Estimator restores its own checkpoints.
        if session.run(tf.train.get_or_create_global_step()) == 0:
            self.init_fn(session)  # load weights here
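If TF-Slim is available (tf.contrib.slim in TF 1.x), its assign_from_checkpoint_fn helper is an alternative way to build such an init_fn; variables_to_restore below stands in for whatever list of pretrained variables you want loaded:

init_fn = tf.contrib.slim.assign_from_checkpoint_fn(
    params['checkpoint_path'],
    variables_to_restore,        # hypothetical: the pretrained variables
    ignore_missing_vars=True)    # skip variables the checkpoint lacks,
                                 # e.g. the new last layer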
This way, the pretrained weights are loaded at the first step and are then saved, together with the new layer, in the model's checkpoints during training.
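Wiring it together might look like this; the paths, params values, and train_input_fn are placeholders:

estimator = tf.estimator.Estimator(
    model_fn=model_fn,
    model_dir='/tmp/finetune',  # the Estimator's own checkpoints go here
    params={'meta_graph_path': 'inception.meta',   # hypothetical paths
            'checkpoint_path': 'inception.ckpt',
            'num_classes': 5})

# First run: the hook restores the pretrained weights at step 0.
# Later runs resume from the checkpoints in model_dir instead.
estimator.train(input_fn=train_input_fn, steps=1000)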