Tags: python-3.x, tensorflow, pre-trained-model

How to run predictions on image using a pretrained tensorflow model?


I have adapted the retrain.py script to work with several pretrained models. After training finishes, it generates a 'retrained_graph.pb' file, which I then load and use to run predictions on an image with this code:

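import tensorflow as tf

def get_top_labels(image_data, graph_path='retrained_graph.pb'):
    '''
    Returns the softmax probabilities for every label
    image_data: raw contents of a JPEG image, as bytes
    '''
    # Load the frozen graph written by retrain.py into its own tf.Graph
    graph_def = tf.compat.v1.GraphDef()
    with tf.io.gfile.GFile(graph_path, 'rb') as f:
        graph_def.ParseFromString(f.read())
    graph = tf.Graph()
    with graph.as_default():
        tf.compat.v1.import_graph_def(graph_def, name='')
    with tf.compat.v1.Session(graph=graph) as sess:
        softmax_tensor = sess.graph.get_tensor_by_name('final_result:0')
        predictions = sess.run(softmax_tensor, {'DecodeJpeg/contents:0': image_data})
        return predictions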
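The `predictions` returned by `sess.run` on 'final_result:0' are only an array of probabilities, one per class. Pairing them with label names could look like the sketch below. It assumes the labels file written at training time has one label per line, in the same order as the model's output classes; `top_labels` and `load_labels` are hypothetical helper names.

```python
import numpy as np


def top_labels(predictions, labels, k=3):
    """Pair the k highest softmax scores with their label names.

    predictions: array of shape (1, num_classes) or (num_classes,),
        e.g. the result of sess.run on 'final_result:0'.
    labels: label strings in the same order as the model's output classes.
    """
    scores = np.asarray(predictions).reshape(-1)  # flatten the batch of one
    top = np.argsort(scores)[::-1][:k]             # class indices, highest first
    return [(labels[i], float(scores[i])) for i in top]


def load_labels(path):
    """Read one label per line, ignoring blank lines (assumed file format)."""
    with open(path) as f:
        return [line.strip() for line in f if line.strip()]
```

For example, `top_labels(np.array([[0.1, 0.7, 0.2]]), ['cat', 'dog', 'bird'], k=2)` gives `[('dog', 0.7), ('bird', 0.2)]`.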

This works fine for the inception_v3 model because its graph has an op called 'DecodeJpeg', but the other models I'm using (inception_v4, mobilenet and inception_resnet_v2) don't.
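To see which tensor each model actually expects as its input, one option is to load the frozen graph and list its placeholder ops. This is a minimal sketch assuming TensorFlow 2.x with the `tf.compat.v1` API; `load_graph` and `find_input_candidates` are hypothetical helper names, and a graph may contain more than one placeholder, so the result is only a list of candidates to inspect.

```python
import tensorflow as tf


def load_graph(graph_path):
    """Import a frozen GraphDef (e.g. retrained_graph.pb) into a new tf.Graph."""
    graph_def = tf.compat.v1.GraphDef()
    with tf.io.gfile.GFile(graph_path, 'rb') as f:
        graph_def.ParseFromString(f.read())
    graph = tf.Graph()
    with graph.as_default():
        tf.compat.v1.import_graph_def(graph_def, name='')
    return graph


def find_input_candidates(graph):
    """Return (tensor name, op type, static shape) for every placeholder op."""
    return [(op.outputs[0].name, op.type, op.outputs[0].shape)
            for op in graph.get_operations()
            if op.type in ('Placeholder', 'PlaceholderWithDefault')]


# Usage (path assumed):
#   for name, op_type, shape in find_input_candidates(load_graph('retrained_graph.pb')):
#       print(name, op_type, shape)
```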

My question is: can I add an op to the graph, like the ones created by add_jpeg_decoding in the retrain.py script, so that I can then use it for prediction?
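One way to do this, sketched here under assumptions, is to build the decoding ops in a fresh graph and then import the retrained GraphDef into it with `input_map`, so the model's input tensor is replaced by the decoded, resized and normalized image. This follows the same decode, resize, normalize idea as `add_jpeg_decoding`, but it is not that function's code. The input/output tensor names and the size, mean and std values are model-specific assumptions you must supply; `build_jpeg_prediction_graph`, `jpeg_input` and the `model/` import prefix are hypothetical names chosen here. Assumes TensorFlow 2.x with the `tf.compat.v1` API.

```python
import tensorflow as tf


def build_jpeg_prediction_graph(graph_def, input_name, output_name,
                                height, width, mean, std):
    """Prepend JPEG decoding ops to a frozen classifier graph.

    Returns (graph, jpeg_input, output): feed raw JPEG bytes to jpeg_input
    and fetch output to get the model's predictions.
    """
    graph = tf.Graph()
    with graph.as_default():
        # String placeholder that accepts the raw JPEG file contents
        jpeg_input = tf.compat.v1.placeholder(tf.string, name='jpeg_input')
        image = tf.io.decode_jpeg(jpeg_input, channels=3)
        image = tf.cast(image, tf.float32)
        image = tf.expand_dims(image, 0)               # add batch dimension
        image = tf.image.resize(image, [height, width])  # bilinear by default
        image = (image - mean) / std                   # model-specific normalization
        # Import the model, wiring the preprocessed image into its input tensor
        output, = tf.compat.v1.import_graph_def(
            graph_def,
            input_map={input_name: image},
            return_elements=[output_name],
            name='model')
    return graph, jpeg_input, output
```

With a graph built this way, `sess.run(output, {jpeg_input: image_data})` inside `tf.compat.v1.Session(graph=graph)` plays the same role that feeding 'DecodeJpeg/contents:0' plays for inception_v3.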

Would it be possible to do something like predictions = sess.run(softmax_tensor, {image_data_tensor: image_data}), where image_data_tensor is a variable that depends on which model I'm using?
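Choosing the feed tensor per model can be as simple as a lookup table. The sketch below is hypothetical: the names and values are recalled from the older, pre-TF-Hub retrain.py (inception_v3 fed through 'Mul:0' at 299x299 with mean/std 128, MobileNet fed through 'input:0' with mean/std 127.5) and should be checked against your own graphs; entries for inception_v4 and inception_resnet_v2 are deliberately left out because their names depend on how the graphs were exported. `ModelInputSpec`, `MODEL_SPECS` and `make_feed_dict` are illustrative names.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ModelInputSpec:
    input_tensor: str  # tensor to feed, e.g. 'input:0'
    height: int        # expected input image height
    width: int         # expected input image width
    mean: float        # subtracted from every pixel value
    std: float         # pixel values are divided by this after the mean


# Assumed values; verify them for each graph you actually use
MODEL_SPECS = {
    'inception_v3': ModelInputSpec('Mul:0', 299, 299, 128.0, 128.0),
    'mobilenet_1.0_224': ModelInputSpec('input:0', 224, 224, 127.5, 127.5),
}


def make_feed_dict(model_name, image_array):
    """Build the feed_dict for sess.run using the model's input tensor name."""
    try:
        spec = MODEL_SPECS[model_name]
    except KeyError:
        raise ValueError(
            f'unknown model {model_name!r}; known: {sorted(MODEL_SPECS)}') from None
    return {spec.input_tensor: image_array}
```

The image_array passed in must already be decoded, resized to (height, width) and normalized with that model's mean and std.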

I looked through Stack Overflow and couldn't find a question that solves my problem. I'd really appreciate any help with this; at the very least I need to know whether it's possible. Sorry for the repost, I got no views on my first one.


Solution

So after some research I figured out a way, and I'm leaving an answer here in case someone needs it. You need to do the decoding yourself: get a tensor from the image with t = read_tensor_from_image_file (found here), then run your predictions with this piece of code:

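import time

def run_prediction(sess, t, input_layer_name='input:0',
                   output_layer_name='final_result:0'):
    # t: preprocessed image array, e.g. from read_tensor_from_image_file
    start = time.time()
    results = sess.run(output_layer_name, {input_layer_name: t})
    end = time.time()
    print('Prediction took {:.3f} s'.format(end - start))
    return results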

Usually input_layer_name = 'input:0' and output_layer_name = 'final_result:0'.
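If you don't want to copy read_tensor_from_image_file, the "decode it yourself" step can also be done eagerly in TensorFlow 2.x and the resulting NumPy array fed to the session. This is a sketch in the same spirit, not that function's code; `image_bytes_to_input` is a hypothetical name, and the height, width, mean and std values are model-specific assumptions.

```python
import tensorflow as tf


def image_bytes_to_input(image_bytes, height, width, mean, std):
    """Decode image bytes into a normalized float32 array of shape
    (1, height, width, 3), ready to feed to the model's input tensor.

    Assumes TF 2.x eager execution (the default).
    """
    # channels=3 forces RGB; expand_animations=False keeps GIFs 3-D
    image = tf.io.decode_image(image_bytes, channels=3, expand_animations=False)
    image = tf.cast(image, tf.float32)
    image = tf.image.resize(image[tf.newaxis, ...], [height, width])  # bilinear
    return ((image - mean) / std).numpy()
```

For a MobileNet-style graph you might call `t = image_bytes_to_input(open(path, 'rb').read(), 224, 224, 127.5, 127.5)` and then `sess.run('final_result:0', {'input:0': t})`, with the sizes and names checked against your own model.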