
How do I use a pretrained network as a layer in Tensorflow?


I want to use a feature extractor (such as ResNet101) and add layers after it that use the output of the feature extractor. However, I can't figure out how. The solutions I have found online all use the entire network without adding any layers. I am inexperienced with TensorFlow.

In the code below you can see what I have tried. The code runs fine without the additional convolutional layer, but my goal is to add more layers after the ResNet. With the extra conv layer added, this error is raised: TypeError: Expected float32, got OrderedDict([('resnet_v1_101/conv1', ...

Once I have added more layers, I would like to start training on a very small test set to see if my model can overfit.


import tensorflow as tf
import tensorflow.contrib.slim as slim
from tensorflow.contrib.slim.python.slim.nets import resnet_v1
import matplotlib.pyplot as plt

numclasses = 17

from google.colab import drive
drive.mount('/content/gdrive')

def decode_text(filename):
  img = tf.io.decode_jpeg(tf.io.read_file(filename))
  img = tf.image.resize_bilinear(tf.expand_dims(img, 0), [224, 224])
  img = tf.squeeze(img, 0)
  img.set_shape((None, None, 3))
  return img

dataset = tf.data.TextLineDataset(tf.cast('gdrive/My Drive/5LSM0collab/filenames.txt', tf.string))
dataset = dataset.map(decode_text)
dataset = dataset.batch(2, drop_remainder=True)

img_1 = dataset.make_one_shot_iterator().get_next()
net = resnet_v1.resnet_v1_101(img_1, 2048, is_training=False, global_pool=False, output_stride=8) 
net = slim.conv2d(net, numclasses, 1)


sess = tf.Session()

global_init = tf.global_variables_initializer()
local_init = tf.local_variables_initializer()
sess.run(global_init)
sess.run(local_init)
img_out, conv_out = sess.run((img_1, net))


Solution

  • resnet_v1.resnet_v1_101 does not return just net; it returns a tuple (net, end_points). The second element is a dictionary, so slim.conv2d receives a tuple containing an OrderedDict instead of a float32 tensor, which is presumably why you are getting this particular error message.

    From the documentation of this function:

    Returns:

    net: A rank-4 tensor of size [batch, height_out, width_out, channels_out]. If global_pool is False, then height_out and width_out are reduced by a factor of output_stride compared to the respective height_in and width_in, else both height_out and width_out equal one. If num_classes is 0 or None, then net is the output of the last ResNet block, potentially after global average pooling. If num_classes is a non-zero integer, net contains the pre-softmax activations.

    end_points: A dictionary from components of the network to the corresponding activation.

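    So you can write, for example:

    # Unpack the tuple: keep net, discard end_points.
    # Passing 2048 as num_classes makes net the pre-softmax activations of an extra
    # 2048-channel layer; pass None instead to get the output of the last ResNet block.
    net, _ = resnet_v1.resnet_v1_101(img_1, 2048, is_training=False, global_pool=False, output_stride=8)
    net = slim.conv2d(net, numclasses, 1)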
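If you are on TensorFlow 2.x, tf.contrib.slim no longer exists. A rough equivalent of the pattern above swaps in tf.keras.applications.ResNet101, which is a different implementation with its own ImageNet weights, not the slim resnet_v1_101 checkpoint. The sketch below assumes TensorFlow 2.x and 224x224 RGB input, and uses weights=None only so it runs offline (pass weights="imagenet" to load the pretrained weights). Keras's ResNet101 has no output_stride argument, so its feature map is downsampled by 32 (7x7 for a 224x224 input) rather than by 8 as in the slim call.

```python
import numpy as np
import tensorflow as tf

NUM_CLASSES = 17  # plays the role of numclasses in the question

# include_top=False drops the global pooling and the 1000-class classifier, so the
# output is the last ResNet block: [batch, 7, 7, 2048] for a 224x224 input.
# weights=None keeps the sketch offline; use weights="imagenet" for pretrained weights.
base = tf.keras.applications.ResNet101(
    include_top=False, weights=None, input_shape=(224, 224, 3))
base.trainable = False  # freeze the feature extractor; only the new layers train

inputs = tf.keras.Input(shape=(224, 224, 3))
# training=False runs batch norm in inference mode, similar to is_training=False.
features = base(inputs, training=False)
# The extra layer after the feature extractor: a 1x1 conv to NUM_CLASSES channels.
outputs = tf.keras.layers.Conv2D(NUM_CLASSES, 1)(features)
model = tf.keras.Model(inputs, outputs)

# Keras ResNet weights expect inputs passed through the matching preprocess_input;
# copy() because preprocess_input may modify NumPy arrays in place.
images = np.zeros((2, 224, 224, 3), dtype=np.float32)  # dummy batch of 2 images
batch = tf.keras.applications.resnet.preprocess_input(images.copy())
conv_out = model(batch)
print(conv_out.shape)  # (2, 7, 7, 17)
```

Freezing the base and training only the new head is one reasonable starting point for the small overfitting test; you can set base.trainable = True later to fine-tune the whole network.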
    

    You can also choose an intermediate layer, e.g.:

    _, end_points = resnet_v1.resnet_v1_101(img_1, 2048, is_training=False, global_pool=False, output_stride=8) 
    net = slim.conv2d(end_points["main_Scope/resnet_v1_101/block3"], numclasses, 1)
    

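    (You can look into end_points to find the names of the endpoints, e.g. by printing its keys. Your scope name will be different from main_Scope: judging by the keys in your error message, which start with resnet_v1_101/, your graph has no enclosing scope, so the key is presumably "resnet_v1_101/block3".)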
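In TensorFlow 2.x Keras, the analogue of looking up a key in end_points is fetching a layer by name from the base model. The sketch below assumes the same setup as a Keras ResNet101 with weights=None (offline). Its conv4 stage has 23 bottleneck blocks and 1024 output channels, so it roughly corresponds to slim's block3, but the spatial size will not match your output_stride=8 slim graph. Check the printed names rather than trusting the layer name used here.

```python
import numpy as np
import tensorflow as tf

NUM_CLASSES = 17  # plays the role of numclasses in the question

base = tf.keras.applications.ResNet101(
    include_top=False, weights=None, input_shape=(224, 224, 3))

# Like inspecting end_points: list candidate layer names (only block outputs here).
block_outputs = [layer.name for layer in base.layers if layer.name.endswith("_out")]
print(block_outputs[-5:])

# Take the output of the last block of the 23-block stage (conv4) instead of the final block.
intermediate = base.get_layer("conv4_block23_out").output
feature_extractor = tf.keras.Model(base.input, intermediate)
feature_extractor.trainable = False  # freeze the truncated ResNet

inputs = tf.keras.Input(shape=(224, 224, 3))
features = feature_extractor(inputs, training=False)
outputs = tf.keras.layers.Conv2D(NUM_CLASSES, 1)(features)  # extra 1x1 conv head
model = tf.keras.Model(inputs, outputs)

conv_out = model(np.zeros((2, 224, 224, 3), dtype=np.float32))  # dummy batch
print(conv_out.shape)  # (2, 14, 14, 17)
```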