
Is it possible to change the input shape of a tensorflow pretrained model?


I have a pre-trained TensorFlow model for image segmentation that takes 6 bands as input. I would like to change its input to accept 4 bands so I can retrain it on my own dataset, but I have not managed to do it, and I am not sure whether this is even possible.

I tried getting the input node by name and replacing it with import_graph_def, with no success: it seems the replacement has to keep the original dimensions.

import tensorflow as tf

# graph_def: GraphDef of the pre-trained 6-band model
graph = tf.get_default_graph()
tf_new_input = tf.placeholder(shape=(4, 256, 256), dtype='float32', name='new_input')
tf.import_graph_def(graph_def, input_map={"ImageInputLayer": tf_new_input})

But I am getting the following error:

tensorflow.python.framework.errors_impl.InvalidArgumentError: Dimensions must be equal, but are 4 and 6 for 'import/ImageInputLayer_Sub' (op: 'Sub') with input shapes: [4,256,256], [6,256,256]
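The message says the imported Sub op receives a [4,256,256] tensor and a [6,256,256] tensor. A plausible reading, assuming the input layer subtracts a per-band tensor stored in the graph as a 6-band constant (for example a mean image), is that remapping the placeholder cannot change that constant's shape. This NumPy sketch reproduces the same incompatibility with hypothetical arrays:

```python
import numpy as np

# Hypothetical stand-ins: a 4-band image and a 6-band constant stored in the graph
new_input = np.zeros((4, 256, 256), dtype=np.float32)
stored_constant = np.zeros((6, 256, 256), dtype=np.float32)

def try_subtract(a, b):
    """Return a - b, or None when the shapes cannot be broadcast together."""
    try:
        return a - b
    except ValueError:
        # NumPy raises ValueError for incompatible shapes, much like
        # TensorFlow's "Dimensions must be equal" check on the Sub op
        return None

result = try_subtract(new_input, stored_constant)
print(result is None)  # True: the leading sizes 4 and 6 cannot be broadcast
```

A 6-band input of the same spatial size subtracts without error, which is why the fix is to produce 6 bands rather than to remap the input directly.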

Solution

  • You have to convert your 4-channel placeholder input into a 6-channel tensor before it reaches the model, and its spatial shape must match what the 6-channel model expects. Any operation that produces that shape works, but a conv2d whose 4-to-6-channel kernel is a trainable variable is an easy layer to put in front of the existing model. This is how you do it:

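    import tensorflow as tf  # TF 1.x graph-mode API

    save_path = '<<save_path>>'  # checkpoint prefix of your 6-channel model

    with tf.Graph().as_default() as old_graph:
      # Load your 6-channel graph and its trained weights here.
      # Assuming that the input node is named 'input_node' and
      # the final node is named 'softmax_node'.
      with tf.Session() as sess:
        saver = tf.train.import_meta_graph(save_path + '.meta')
        saver.restore(sess, save_path)
        # Freeze variables into constants so import_graph_def carries the weights.
        old_graph_def = tf.graph_util.convert_variables_to_constants(
            sess, old_graph.as_graph_def(), ['softmax_node'])

    with tf.Graph().as_default() as new_graph:
      tf_new_input = tf.placeholder(shape=(None, 256, 256, 4), dtype=tf.float32)

      # Map the 4-channel input to 6 channels with a trainable 3x3 kernel.
      # Stride 1 and 'SAME' padding keep the 256x256 size the old model expects.
      kernel = tf.get_variable('adapter_kernel', shape=(3, 3, 4, 6))
      new_node = tf.nn.conv2d(tf_new_input, kernel, strides=[1, 1, 1, 1], padding='SAME')

      # Splice new_node in place of the old input and fetch the output tensor
      # so that you can perform further operations on it.
      softmax_node, = tf.import_graph_def(old_graph_def, input_map={'input_node:0': new_node},
                                          return_elements=['softmax_node:0'])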
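The 4-to-6 channel mapping done by tf.nn.conv2d above can be checked without TensorFlow. Here is a NumPy sketch of a stride-1, 'SAME'-padded 3x3 convolution in NHWC layout; the helper conv2d_same and the random data are hypothetical, for illustration only. It shows a (3, 3, 4, 6) kernel turning a 4-band batch into a 6-band batch while keeping the 256x256 size:

```python
import numpy as np

def conv2d_same(x, kernel):
    """Stride-1, 'SAME'-padded 2-D convolution in NHWC layout (hypothetical helper).

    x: (N, H, W, C_in); kernel: (kh, kw, C_in, C_out) with odd kh, kw.
    Like tf.nn.conv2d, this is a cross-correlation (the kernel is not flipped).
    """
    kh, kw, _, c_out = kernel.shape
    n, h, w, _ = x.shape
    ph, pw = kh // 2, kw // 2
    # Zero-pad height and width so the output keeps the H x W size
    padded = np.pad(x, ((0, 0), (ph, ph), (pw, pw), (0, 0)))
    out = np.zeros((n, h, w, c_out), dtype=x.dtype)
    for i in range(kh):
        for j in range(kw):
            # Each kernel tap is a C_in x C_out matrix applied to a shifted view
            out += padded[:, i:i + h, j:j + w, :] @ kernel[i, j]
    return out

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 256, 256, 4)).astype(np.float32)   # 4-band batch
kernel = rng.standard_normal((3, 3, 4, 6)).astype(np.float32)  # 4 -> 6 bands
y = conv2d_same(x, kernel)
print(y.shape)  # (2, 256, 256, 6)
```

In the TensorFlow graph the kernel is a variable, so retraining can learn how to combine your 4 bands into the 6 inputs the pre-trained layers were built for. Note also that the error in the question shows the imported graph expecting a channels-first [6,256,256] tensor; if that is your model's layout, the NHWC adapter output would need to be transposed to channels-first (for example with tf.transpose) before it is mapped onto the old input.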