tensorflow, google-colaboratory

How to feed TensorFlow an image from a URL?


The five commented-out lines below should work, but they do not. The prediction score is nowhere close to what I would expect, and plt.imshow(img) shows the wrong image. Here is the link to my notebook in Colab.

x, y = next(valid_generator)
image = x[0, :, :, :]
true_index = np.argmax(y[0])
plt.imshow(image)

image_url = 'https://mysite_example/share/court3.jpg'
image_url = tf.keras.utils.get_file('Court', origin=image_url )

#img = keras.preprocessing.image.load_img( image_url, target_size=( 224, 224 ) )
#img_array = keras.preprocessing.image.img_to_array(img)
#img_array = tf.expand_dims(img_array, 0) 
#prediction_scores = model.predict(np.expand_dims(img_array, axis=0))
#plt.imshow(img)

# Expand the validation image to (1, 224, 224, 3) before predicting the label
prediction_scores = model.predict(np.expand_dims(image, axis=0))
predicted_index = np.argmax(prediction_scores)
print("True label: " + get_class_string_from_index(true_index))
print("Predicted label: " + get_class_string_from_index(predicted_index))
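The commented-out lines also add the batch axis twice: tf.expand_dims(img_array, 0) already produces shape (1, 224, 224, 3), and np.expand_dims(..., axis=0) inside model.predict adds a second leading axis. The sketch below checks the shapes with NumPy alone; np.expand_dims stands in for tf.expand_dims, which adds an axis the same way, and the all-zeros array is only a placeholder for the loaded image.

```python
import numpy as np

# Placeholder for img_to_array(load_img(..., target_size=(224, 224))):
# one RGB image of shape (height, width, channels).
img_array = np.zeros((224, 224, 3), dtype=np.float32)

# First batch axis (the tf.expand_dims(img_array, 0) line).
once = np.expand_dims(img_array, axis=0)
print(once.shape)   # (1, 224, 224, 3), the shape the model expects

# Second batch axis (np.expand_dims inside model.predict): one axis too many.
twice = np.expand_dims(once, axis=0)
print(twice.shape)  # (1, 1, 224, 224, 3)
```

Adding the batch axis exactly once, either with tf.expand_dims or with np.expand_dims, gives the model the (1, 224, 224, 3) input it expects.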

Solution

  • tf.keras.utils.get_file downloads the file from the URL into a local cache only if a file with that name is not already cached. So if you use the same cache name for every URL ("Court" in your code), you will only ever get the first file.

    Also, during training you have a preprocessing step that normalizes all pixels by dividing them by 255. You have to apply the same preprocessing step during inference.

    Working Code:

    import os

    import matplotlib.pyplot as plt
    import numpy as np
    import tensorflow as tf

    # `model` is the trained classifier from the notebook.
    _, axis = plt.subplots(1, 3)

    for i, image_url in enumerate(['https://squashvideo.site/share/court3.jpg',
                                   'https://i.pinimg.com/originals/0f/c2/9b/0fc29b35532f8e2fb998f5605212ab27.jpg',
                                   'https://thumbs.dreamstime.com/b/squash-court-photo-empty-30346175.jpg']):
      # Download into the Keras cache under the name 'Court'
      image_path = tf.keras.utils.get_file('Court', origin=image_url)
      img = tf.keras.preprocessing.image.load_img(image_path, target_size=(224, 224))
      os.remove(image_path)  # Delete the cached file so the next URL is downloaded fresh
      axis[i].imshow(img)

      img_array = tf.keras.preprocessing.image.img_to_array(img)
      # Add the batch axis and rescale pixels to [0, 1], as during training
      prediction_scores = model.predict(np.expand_dims(img_array, axis=0) / 255)
      axis[i].title.set_text(np.argmax(prediction_scores, axis=1))
    

    Output: (image not preserved: the three downloaded images, each titled with its predicted class index)

    As you can see, the predictions are correct: the last image belongs to class 0 (empty squash court) and the second image belongs to class 1 (players playing in a squash court).
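If you would rather keep the downloads cached than delete each file with os.remove, another option is to give every URL its own cache name, so get_file never returns a file cached for a different URL. The following is a minimal sketch assuming a hash-based naming scheme; cache_name_for and download_image are hypothetical helper names, and only tf.keras.utils.get_file(fname, origin=...) comes from TensorFlow.

```python
import hashlib
import os
from urllib.parse import urlparse


def cache_name_for(url: str) -> str:
    """Build a cache file name that is unique per URL (hypothetical helper).

    Keeps the URL's file extension (e.g. ".jpg") and prefixes a short
    SHA-256 digest of the full URL, so two different URLs never share
    a cached file.
    """
    path = urlparse(url).path
    ext = os.path.splitext(path)[1] or ".img"  # fall back if the URL has no extension
    digest = hashlib.sha256(url.encode("utf-8")).hexdigest()[:16]
    return f"court_{digest}{ext}"


def download_image(url: str) -> str:
    """Download `url` into the Keras cache under a URL-specific name."""
    import tensorflow as tf  # imported here so the naming helper works without TensorFlow
    return tf.keras.utils.get_file(cache_name_for(url), origin=url)
```

With this, download_image(image_url) can replace the get_file('Court', ...) call, and repeated runs reuse each image's own cached copy instead of re-downloading.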
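Repeating the training-time preprocessing at inference is easiest to get right when both paths share one function. The sketch below assumes the model was trained on pixels rescaled from [0, 255] to [0, 1] (for example with a generator built using rescale=1./255); preprocess_for_inference is a hypothetical name, and you should adjust the body if your training pipeline differs.

```python
import numpy as np


def preprocess_for_inference(img_array: np.ndarray) -> np.ndarray:
    """Apply the (assumed) training-time preprocessing to one RGB image.

    Rescales pixels from [0, 255] to [0, 1] and adds the batch axis once,
    returning a float32 array of shape (1, height, width, channels).
    """
    scaled = img_array.astype(np.float32) / 255.0
    return np.expand_dims(scaled, axis=0)
```

In the loop from the answer, this would replace the inline expression: model.predict(preprocess_for_inference(img_array)).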