Tags: python, matplotlib, machine-learning, tensorflow-datasets, tensorflow2.x

How to show an image of a certain class with predictions in a multiclass classification problem?


I've been working on a multiclass classification problem where I need a function that shows images of a certain class from the Fashion MNIST dataset together with the model's predictions for them. For example, plot 3 images of the T-shirt class with their predictions. I have tried different things without success. I'm missing a conditional statement and I can't figure out how and where to implement it in my function.

This is what I've come up with so far:

import numpy as np
import matplotlib.pyplot as plt

# Make function to plot image
def plot_image(indx, predictions, true_labels, target_images):
  """
  Picks an image, plots it and labels it with a predicted and truth label.

  Args:
  indx: index number to find the image and its true label.
  predictions: model predictions on test data (each array is a predicted probability of values between 0 to 1).
  true_labels: array of ground truth labels for images.
  target_images: images from the test data (in tensor form).

  Returns:
  A plot of an image from `target_images` with a predicted class label
  as well as the truth class label from `true_labels`.
  """
  # Set target image
  target_image = target_images[indx]
  # Truth label
  true_label = true_labels[indx]
  # Predicted label
  predicted_label = np.argmax(predictions)  # find the index of max value

  # Show image
  plt.imshow(target_image, cmap=plt.cm.binary)
  plt.xticks([])
  plt.yticks([])

  # Set colors for right or wrong predictions
  if predicted_label == true_label:
    color = 'green'
  else:
    color = 'red'

  # Labels appear on the x-axis along with accuracy %
  plt.xlabel("{} {:2.0f}% ({})".format(class_names[predicted_label],
                                100*np.max(predictions),
                                class_names[true_label]),
                                color=color)



# Function to display image of a class
def display_image(class_indx):
  # Set figure size
  plt.figure(figsize=(10,10))

  # Set class index
  class_indx = class_indx

  # Display 3 images
  for i in range(3):
    plt.subplot(1, 3, i+1)
    # plot_image function
    plot_image(indx=class_indx, predictions=y_probs[class_indx],
               true_labels=test_labels, target_images=test_images_norm)

These are the class names: 'T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot'. When I call display_image() and pass an index as the argument, e.g. display_image(class_indx=15), I get the same image and the same prediction 3 times. (I know this approach is wrong: I'm passing an index number when it should be the class name.) I need a function that takes a str (the class name) and displays 3 different images of that class with their predictions. For instance, display_image('Dress') should show 3 images of the Dress class along with the 3 different predictions my model has made for them, e.g. Prediction #1 (65%), Prediction #2 (100%), Prediction #3 (87%). Thanks!


Solution

  • I think you are really close to solving your problem. You just need to draw three samples from your category of interest. I guess that you used a le = LabelEncoder() to encode your target vector. If so, you can get the class names with labels = list(le.classes_); otherwise the class_names list from your question serves the same purpose. Then I would do the following:

    def display_image(class_of_interest: str, nb_samples: int = 3):
        plt.figure(figsize=(10, 10))

        # Map the class name to its integer label
        class_indx = class_names.index(class_of_interest)
        # Indices of all test images whose true label is that class
        target_idx = np.where(test_labels == class_indx)[0]
        # Draw nb_samples distinct indices at random
        imgs_idx = np.random.choice(target_idx, nb_samples, replace=False)

        for i in range(nb_samples):
            plt.subplot(1, nb_samples, i + 1)
            plot_image(indx=imgs_idx[i],
                       predictions=y_probs[imgs_idx[i]],
                       true_labels=test_labels,
                       target_images=test_images_norm)
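
The index-selection step can be checked on its own, without a model or any plotting. Below is a minimal sketch, assuming the labels are plain integers (not one-hot vectors) that index into class_names. The helper name pick_class_indices, the seed argument, and the tiny label array are made up for illustration and are not part of the original code.

```python
import numpy as np

class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
               'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']


def pick_class_indices(labels, class_name, nb_samples=3, seed=None):
    """Return nb_samples distinct indices whose label matches class_name.

    Hypothetical helper: it assumes `labels` holds integer class ids.
    """
    # list.index raises ValueError if the name is not a known class
    class_indx = class_names.index(class_name)
    # Positions of every sample that belongs to the requested class
    candidates = np.flatnonzero(np.asarray(labels) == class_indx)
    if candidates.size < nb_samples:
        raise ValueError(
            f"only {candidates.size} samples of class {class_name!r} available")
    rng = np.random.default_rng(seed)
    # replace=False guarantees the same image is never picked twice
    return rng.choice(candidates, size=nb_samples, replace=False)


# Synthetic labels standing in for test_labels (3 == 'Dress')
test_labels = np.array([3, 0, 3, 9, 3, 1, 3, 0])
idx = pick_class_indices(test_labels, 'Dress', nb_samples=3, seed=0)
print(idx, test_labels[idx])
```

Each returned index can then be passed to plot_image as indx, with y_probs[i] as its predictions. Raising an error when the class has too few samples is a design choice for this sketch; np.random.choice with replace=False would fail in that situation as well.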