python, tensorflow, gcloud, google-prediction, google-cloud-ml-engine

Tensorflow: Resize Image Placeholder


I have a trained TF model that operates on serialized (TFRecord) input. The image data has a variable shape and is converted to 229x229x3 via tf.image.resize_images(...). I would like to use the gcloud ml-engine predict platform similar to this, while making sure it accepts images of any size as input.

I get my features tensor (which is passed to the prediction graph) from the following function:

import tensorflow as tf

def jpeg_serving_input_fn():
  """
  Serve single jpeg feature to the prediction graph
  :return: Image as a tensor
  """
  # Variable-height, variable-width RGB image.
  input_features = tf.placeholder(dtype=tf.float32, shape=[None, None, 3],
                                  name="PREDICT_PLACEHOLDER")
  # Resize to the fixed size the model was trained on.
  features_normalized = tf.image.resize_images(input_features, [229, 229])

  # Add a batch dimension of 1: [229, 229, 3] -> [1, 229, 229, 3].
  image = tf.reshape(features_normalized, [1, 229, 229, 3], name="RESHAPE_PREDICT")

  inputs = {
    'image': image
  }

The tf.reshape at the end is there because my prediction graph expects a tensor of shape [batch_size, 229, 229, 3]. When I run this locally with

gcloud ml-engine local predict \
--model-dir=trained_model/export/ \
--json-instances=img.json

I get a PredictionError:

predict_lib_beta.PredictionError: (4, "Exception during running the graph: Cannot feed value of shape (1, 1600, 2400, 3) for Tensor u'RESHAPE_PREDICT:0', which has shape '(1, 229, 229, 3)'")

It looks to me as though tf.reshape is being fed the output of tf.image.resize_images, which should already have the correct shape. Any thoughts on what I'm doing wrong here? Thanks in advance!
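The error can be reproduced outside ml-engine with a small graph-mode sketch. It assumes TensorFlow's `tf.compat.v1` API (so it also runs on TF 2.x) and assumes, as the error message implies, that the prediction code feeds the `RESHAPE_PREDICT` tensor directly. The all-zero image is a hypothetical stand-in for a decoded JPEG, kept smaller than 1600x2400 to stay light.

```python
import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1  # graph-mode API; the question's code is TF 1.x

graph = tf.Graph()
with graph.as_default():
    # Same ops as the question's serving function.
    input_features = tf1.placeholder(
        dtype=tf.float32, shape=[None, None, 3], name="PREDICT_PLACEHOLDER")
    features_normalized = tf1.image.resize_images(input_features, [229, 229])
    image = tf.reshape(
        features_normalized, [1, 229, 229, 3], name="RESHAPE_PREDICT")

# Hypothetical stand-in for a decoded JPEG.
raw = np.zeros((160, 240, 3), dtype=np.float32)

with tf1.Session(graph=graph) as sess:
    # Feeding the placeholder: resize + reshape produce the fixed shape.
    out = sess.run(image, feed_dict={input_features: raw})
    print(out.shape)  # (1, 229, 229, 3)

    # Feeding the reshape output itself bypasses the resize entirely, and
    # the fed value must match RESHAPE_PREDICT's static shape (1, 229, 229, 3).
    error_message = ""
    try:
        sess.run(image, feed_dict={image: raw[np.newaxis]})
    except ValueError as err:
        error_message = str(err)
    print(error_message)
```

The second `sess.run` raises the same "Cannot feed value of shape ..." error as the prediction service, which points at which tensor is being fed rather than at the resize itself.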


Solution

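  • The error appears to come from code that feeds the "RESHAPE_PREDICT:0" tensor (i.e. image, the output of the tf.reshape() op) rather than the "PREDICT_PLACEHOLDER:0" tensor (i.e. input_features, the input to the tf.image.resize_images() op). In graph-mode TensorFlow any tensor can be fed, not just a placeholder; feeding one replaces its computed value, and the fed value must be compatible with that tensor's static shape, which for the reshape output is fixed at (1, 229, 229, 3).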

    Without the full source of your trained model it's hard to say exactly which changes are necessary, but the fix might be as simple as changing the definition of inputs to:

    inputs = {'image': input_features}
    

    ...so that the prediction service knows to feed values to that placeholder, rather than the fixed-shape output of tf.reshape().
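One further point, offered as a hedged suggestion: the error message shows the service passing a value of shape (1, 1600, 2400, 3), i.e. with a leading batch dimension. A placeholder declared as [None, None, 3] would reject that rank-4 value with the same kind of shape error, so the placeholder probably needs a batch dimension too. The sketch below assumes that, and assumes the `tf.compat.v1` graph-mode API; `build_batched_serving_graph` is a hypothetical helper name. Because tf.image.resize_images accepts a 4-D batch, the tf.reshape is no longer needed.

```python
import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1  # graph-mode API, matching the question's TF 1.x code


def build_batched_serving_graph(height=229, width=229):
    """Hypothetical helper: a serving graph whose fed tensor accepts any size.

    Assumption: the prediction service feeds a batch of images, so the
    placeholder is [batch, height, width, channels] with only channels fixed.
    """
    graph = tf.Graph()
    with graph.as_default():
        input_features = tf1.placeholder(
            dtype=tf.float32, shape=[None, None, None, 3],
            name="PREDICT_PLACEHOLDER")
        # resize_images accepts a 4-D batch and resizes every image in it,
        # so no tf.reshape is needed to add the batch dimension.
        image = tf1.image.resize_images(input_features, [height, width])
        # The service should feed the placeholder, not the resized tensor.
        inputs = {"image": input_features}
    return graph, inputs, image


graph, inputs, image = build_batched_serving_graph()
with tf1.Session(graph=graph) as sess:
    # Batches with different spatial sizes all come out as 229x229.
    small = sess.run(image, feed_dict={
        inputs["image"]: np.zeros((1, 160, 240, 3), np.float32)})
    large = sess.run(image, feed_dict={
        inputs["image"]: np.zeros((2, 300, 100, 3), np.float32)})
print(small.shape, large.shape)  # (1, 229, 229, 3) (2, 229, 229, 3)
```

Since a placeholder is fed one dense array, images stacked into the same batch must still share a height and width; only different feeds can differ in size.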