
TensorFlow - transfer learning implementation (semantic segmentation)


I'm implementing a CNN architecture (an FCN-8s model built on a pretrained VGG16) for semantic segmentation of my own data. The data has 2 classes, so this is a binary per-pixel classification.

Here is how I intend to go about it:

  1. Load the pretrained model with its weights.
  2. Add or remove higher layers to convert it into an FCN.
  3. Freeze the lower layers of the pretrained model so they are not updated during training.
  4. Train the network on my specific dataset.
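As a rough sketch of these four steps, here is one way to do them with the Keras API in TensorFlow 2.x (not the session-based TF 1.x code used later in this question). It assumes `tf.keras.applications.VGG16` as the pretrained encoder and, for brevity, replaces the real FCN-8s decoder with a single FCN-32s-style head: a 1x1 scoring convolution plus one stride-32 transposed convolution, with no skip connections. `build_fcn` and the layer names are hypothetical. Freezing is done by setting `trainable = False` on the base model, which removes its weights from `model.trainable_variables`.

```python
import tensorflow as tf


def build_fcn(num_classes=2, weights="imagenet"):
    """Hypothetical helper: frozen VGG16 encoder + trainable FCN-32s-style head."""
    # Step 1: load VGG16 without its fully connected classifier layers.
    # weights="imagenet" downloads the pretrained weights on first use.
    base = tf.keras.applications.VGG16(
        include_top=False, weights=weights, input_shape=(None, None, 3))

    # Step 3: freeze every layer of the pretrained encoder.
    base.trainable = False

    # Step 2: new higher layers that turn the encoder into an FCN.
    inputs = tf.keras.Input(shape=(None, None, 3))
    features = base(inputs, training=False)           # [batch, h/32, w/32, 512]
    scores = tf.keras.layers.Conv2D(num_classes, 1, name="score")(features)
    logits = tf.keras.layers.Conv2DTranspose(
        num_classes, kernel_size=64, strides=32, padding="same",
        name="upsample")(scores)                      # [batch, h, w, num_classes]
    model = tf.keras.Model(inputs, logits)

    # Step 4: per-pixel softmax loss on integer labels of shape [batch, h, w].
    model.compile(
        optimizer="adam",
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=["accuracy"])
    return model
```

Training is then `model.fit(images, labels, ...)`. The output matches the input size only when height and width are multiples of 32 (750 is not, so such inputs would need padding, or the output cropping). ImageNet-pretrained VGG16 also expects its inputs to go through `tf.keras.applications.vgg16.preprocess_input` first.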

Assuming this is correct, how do I freeze the lower layers of my TensorFlow model? (I'm looking for specific implementation details.) I had a look at the TensorFlow tutorial on retraining Inception, but I'm still not sure.

This is the workflow I have in mind:

  1. Run my data through the existing pretrained model and extract the feature outputs, without training it. (How?)

  2. Feed these feature outputs into another network containing the higher layers, and train that network.
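The two-step workflow above can be sketched in TensorFlow 2.x by running the frozen encoder once in inference mode, caching its outputs, and training a separate head on them. `extract_features` and `build_head` are hypothetical names, the tiny `encoder` only stands in for the pretrained VGG16 part, and the images and labels are random placeholders. Because the head never sees the encoder's variables, nothing in the encoder can be updated.

```python
import numpy as np
import tensorflow as tf


def extract_features(encoder, images, batch_size=8):
    # Step 1: forward pass only (inference mode); no gradients are computed.
    return encoder.predict(images, batch_size=batch_size, verbose=0)


def build_head(num_classes, upsample):
    # Step 2: hypothetical trainable higher layers that map cached features
    # back to per-pixel logits of shape [batch, h, w, num_classes].
    head = tf.keras.Sequential([
        tf.keras.layers.Conv2D(num_classes, 1),
        tf.keras.layers.Conv2DTranspose(
            num_classes, kernel_size=2 * upsample, strides=upsample, padding="same"),
    ])
    head.compile(
        optimizer="adam",
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True))
    return head


# Stand-in for the pretrained encoder: downsamples 32x32 inputs by a factor of 4.
encoder = tf.keras.Sequential([
    tf.keras.Input(shape=(32, 32, 3)),
    tf.keras.layers.Conv2D(8, 3, padding="same", activation="relu"),
    tf.keras.layers.MaxPooling2D(4),
])

# Random placeholder data: 16 RGB images and their binary per-pixel labels.
images = np.random.rand(16, 32, 32, 3).astype("float32")
labels = np.random.randint(0, 2, size=(16, 32, 32))

features = extract_features(encoder, images)   # cached once, shape (16, 8, 8, 8)
head = build_head(num_classes=2, upsample=4)
head.fit(features, labels, epochs=1, batch_size=4, verbose=0)
```

The cached `features` can be saved to disk and reused across epochs, which is cheaper than re-running the encoder; the trade-off is that data augmentation on the raw images is no longer possible after caching.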

Any suggestions would be helpful!

Otherwise, if I'm wrong, how should I be thinking about this?

UPDATE:

I took up chasep255's suggestion below and tried to use tf.stop_gradient to "freeze" the lower layers in my model. Clearly, something is wrong with my implementation. Any alternatives or suggestions?

The model is based on the FCN paper for semantic segmentation. I extract the logits from the model, i.e., my features, and initially fed them directly into a loss function that is minimized with a softmax classifier (per-pixel classification). deconv_1 is my logits tensor, of shape [batch, h, w, num_classes] = [1, 750, 750, 2]. Implementation:

logits = vgg_fcn.deconv_1

stopper = tf.stop_gradient(logits, 'stop_gradients')

loss = train_func.loss(stopper, labels_placeholder, 2)

with tf.name_scope('Optimizer'):
    train_op = train_func.training(loss, FLAGS.learning_rate)

    with tf.name_scope('Accuracy'):
        eval_correct = train_func.accuracy_eval(logits, labels_placeholder)
        accuracy_summary = tf.scalar_summary('Accuracy', eval_correct)

I then run these graph operations as follows:

_, acc, loss_value = sess.run([train_op, eval_correct, loss], feed_dict=feed_dict)

When I run the training loop this way, the loss value is not optimized at all, almost certainly because of how I've introduced the tf.stop_gradient op.

For more details, here is my loss function:

def loss(logits, labels, num_classes):

    # Flatten [batch, h, w, num_classes] logits to [batch*h*w, num_classes]
    logits = tf.reshape(logits, [-1, num_classes])
    #epsilon = tf.constant(value=1e-4)
    #logits = logits + epsilon

    # Flatten [batch, h, w] integer labels to [batch*h*w]
    labels = tf.to_int64(tf.reshape(labels, [-1]))
    print('shape of logits: %s' % str(logits.get_shape()))
    print('shape of labels: %s' % str(labels.get_shape()))

    # Per-pixel softmax cross-entropy; TF >= 1.0 requires named arguments here
    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=labels, logits=logits, name='Cross_Entropy')
    cross_entropy_mean = tf.reduce_mean(cross_entropy, name='xentropy_mean')
    tf.add_to_collection('losses', cross_entropy_mean)

    # Total loss = sum of everything in the 'losses' collection
    loss = tf.add_n(tf.get_collection('losses'), name='total_loss')
    return loss
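The symptom in the update can be reproduced in a few lines. Applying tf.stop_gradient to deconv_1, the final logits, cuts the gradient path from the loss to every variable, including the new higher layers, so nothing can be trained. It has to wrap the output of the last layer you want frozen instead. This TensorFlow 2.x sketch uses `tf.GradientTape` rather than the question's session API, and a hypothetical one-layer `encoder` (standing in for the pretrained layers) and `head` (standing in for the new layers).

```python
import tensorflow as tf

# Hypothetical stand-ins: `encoder` plays the pretrained layers, `head` the new ones.
encoder = tf.keras.layers.Conv2D(4, 3, padding="same", name="encoder")
head = tf.keras.layers.Conv2D(2, 1, name="head")

images = tf.random.normal([1, 8, 8, 3])
labels = tf.random.uniform([1, 8, 8], maxval=2, dtype=tf.int32)
head(encoder(images))  # build both layers so their variables exist


def per_pixel_loss(logits, labels, num_classes=2):
    # Same computation as loss() above: flatten, then sparse softmax cross-entropy.
    logits = tf.reshape(logits, [-1, num_classes])
    labels = tf.cast(tf.reshape(labels, [-1]), tf.int64)
    return tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits))


def gradients(stop_at):
    """Return (encoder_grads, head_grads) with tf.stop_gradient placed at `stop_at`."""
    with tf.GradientTape() as tape:
        features = encoder(images)
        if stop_at == "features":   # freezes only the encoder
            features = tf.stop_gradient(features)
        logits = head(features)
        if stop_at == "logits":     # as in the question: blocks every variable
            logits = tf.stop_gradient(logits)
        loss = per_pixel_loss(logits, labels)
    grads = tape.gradient(loss, encoder.trainable_variables + head.trainable_variables)
    return grads[:2], grads[2:]


def has_gradient(grads):
    # A blocked path shows up as a None (or all-zero) gradient.
    return any(g is not None and bool(tf.reduce_any(g != 0)) for g in grads)


enc_g, head_g = gradients("logits")
print(has_gradient(enc_g), has_gradient(head_g))    # neither part receives gradients
enc_g, head_g = gradients("features")
print(has_gradient(enc_g), has_gradient(head_g))    # only the head receives gradients
```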

Solution

  • You could simply pass the pretrained model's output tensor to sess.run(pretrained_output, ...) and capture its values. After saving that output, you can feed it into your own model. In this case the optimizer cannot propagate gradients into the pretrained model.

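    You could also attach the pretrained model to your model normally and then pass the pretrained output through tf.stop_gradient(), which prevents the optimizer from propagating gradients back into the pretrained model. Apply it to the output of the last pretrained layer you want frozen, not to the final logits: wrapping the logits blocks the gradients to every variable, including those in your new layers.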

    Finally, you could just go through all the variables in the pretrained model and remove them from the list of trainable variables.
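The third option amounts to handing the optimizer only the variables of the new layers. In TF 1.x this is the `var_list` argument of `Optimizer.minimize` (for example, a filtered subset of `tf.trainable_variables()`); the TensorFlow 2.x sketch below does the equivalent in a custom training step. `encoder`, `head` and `train_step` are hypothetical stand-ins for the pretrained layers, the new layers, and your training loop, and the data is random.

```python
import tensorflow as tf

# Hypothetical stand-ins for the pretrained layers and the new higher layers.
encoder = tf.keras.layers.Conv2D(4, 3, padding="same", name="pretrained_encoder")
head = tf.keras.layers.Conv2D(2, 1, name="new_head")
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)


def train_step(images, labels):
    with tf.GradientTape() as tape:
        logits = head(encoder(images))     # [batch, h, w, 2]
        loss = loss_fn(labels, logits)     # per-pixel softmax cross-entropy
    # Only the head's variables are given to the optimizer, so the
    # encoder (the "frozen" part) is never updated.
    train_vars = head.trainable_variables
    grads = tape.gradient(loss, train_vars)
    optimizer.apply_gradients(zip(grads, train_vars))
    return loss


images = tf.random.normal([2, 8, 8, 3])
labels = tf.random.uniform([2, 8, 8], maxval=2, dtype=tf.int32)
head(encoder(images))  # build the layers so their weights can be snapshotted
encoder_before = [v.numpy().copy() for v in encoder.trainable_variables]
head_before = [v.numpy().copy() for v in head.trainable_variables]
train_step(images, labels)
```

In Keras, setting `encoder.trainable = False` has the same effect, since it removes the encoder's weights from `trainable_variables`.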