Tags: python-3.x, tensorflow, tensorflow-datasets

Plotting images from a TensorFlow dataset shows only one image


I was working with the Fashion MNIST dataset in TensorFlow and wanted to plot every image in test_data whose label in test_labels is 9. I ran the following code. It runs without errors, but it shows only one image, even though many images have that label.

import matplotlib.pyplot as plt

for i in range(len(test_data)):
    if test_labels[i] == 9:
        # Every call draws on the same current axes
        plt.imshow(test_data[i])

This is the output I get:

[Image: the output, a single plotted image]


Solution

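  • Each call to plt.imshow draws onto the same current axes, so every matching image is painted over the previous one and only the last match stays visible.

    To show each matching image in its own panel, create a grid of subplots and draw each image on its own axes, for example: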

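    import matplotlib.pyplot as plt
    import tensorflow as tf

    # Load Fashion MNIST; test_data has shape (10000, 28, 28)
    (train_data, train_labels), (test_data, test_labels) = tf.keras.datasets.fashion_mnist.load_data()

    j = 0            # number of images plotted so far
    n_rows = 2
    n_cols = 3
    for i in range(len(test_data)):
        if test_labels[i] == 9:
            j += 1
            # Draw the j-th matching image in its own cell of the grid
            ax = plt.subplot(n_rows, n_cols, j)
            ax.imshow(test_data[i])
        if j >= n_rows * n_cols:
            break    # the grid is full
    plt.show()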
    

    The result: [Image: a 2-by-3 grid of images with label 9]
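The behavior described in this answer can be checked directly. The sketch below is a minimal illustration, assuming only NumPy and Matplotlib, with random arrays standing in for test_data (the name `images` is hypothetical): it calls `imshow` five times on one Axes and then counts the images attached to that Axes.

```python
import matplotlib
matplotlib.use("Agg")  # render off-screen; no window is needed for this check
import matplotlib.pyplot as plt
import numpy as np

# Hypothetical stand-in for test_data: five random 28x28 "images"
images = np.random.rand(5, 28, 28)

fig, ax = plt.subplots()
for img in images:
    ax.imshow(img)  # every call targets the same Axes

# All five images are attached to one Axes and stacked on top of each other,
# so only the last one is visible when the figure is shown.
print(len(ax.images))
```

The count is 5 while the figure still has a single Axes, which is why only the topmost (last) image appears. Giving each image its own Axes, as the loop with `plt.subplot` does, avoids the stacking.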
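A variation on the same idea, sketched under the assumption that the labels are a NumPy array (which is what `tf.keras.datasets.fashion_mnist.load_data()` returns), selects all matching indices in one step with `np.flatnonzero` instead of counting inside the loop, and builds the whole grid up front with `plt.subplots`. The helper `plot_label_grid` and the random example data are hypothetical, for illustration only.

```python
import matplotlib
matplotlib.use("Agg")  # off-screen rendering; drop this line to display windows
import matplotlib.pyplot as plt
import numpy as np


def plot_label_grid(images, labels, label, n_rows=2, n_cols=3):
    """Hypothetical helper: plot up to n_rows * n_cols images whose label equals `label`."""
    labels = np.asarray(labels)
    # Indices of all matching images, truncated to the number of grid cells
    idx = np.flatnonzero(labels == label)[: n_rows * n_cols]
    # squeeze=False keeps `axes` 2-D even for a single row or column
    fig, axes = plt.subplots(n_rows, n_cols, squeeze=False)
    for ax in axes.flat:
        ax.axis("off")  # hide ticks; cells without a match stay blank
    for ax, i in zip(axes.flat, idx):
        ax.imshow(images[i], cmap="gray")  # one image per Axes
    return fig, idx


# Hypothetical example data: 20 random 28x28 images, labels 0..9 repeated twice
rng = np.random.default_rng(0)
images = rng.random((20, 28, 28))
labels = np.arange(20) % 10
fig, idx = plot_label_grid(images, labels, 9)
print(idx)  # [ 9 19]
```

With the arrays from the question, `plot_label_grid(test_data, test_labels, 9)` followed by `plt.show()` (and without the `Agg` line, which only renders off-screen) should give a grid similar to the one from the loop; the gray colormap and hidden axes are cosmetic choices.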