python-3.x, matplotlib, keras, typeerror, mnist

Invalid dimension for image data in plt.imshow()


I am using the MNIST dataset to train a capsule network with Keras. After training, I want to display an image from the MNIST dataset. The images are loaded with mnist.load_data() and stored as (x_train, y_train), (x_test, y_test). To visualize one image, my code is as follows:

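import matplotlib.pyplot as plt

img_path = x_test[1]  # a single test image (an array, not a file path)
print(img_path.shape)
plt.imshow(img_path)  # fails with the TypeError below
plt.show()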

The code prints:

(28, 28, 1)

and then plt.imshow(img_path) fails with this error:

TypeError: Invalid dimensions for image data

How can I show the image, for example in PNG format? Help!


Solution

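  • You can use tf.squeeze to remove dimensions of size 1 from the shape of a tensor, turning the (28, 28, 1) image into a (28, 28) array that plt.imshow() accepts. Pass the squeezed image itself to imshow, not its tf.shape, and pick a single sample rather than the whole x_train array.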

    plt.imshow(tf.squeeze(x_test[1]), cmap="gray")  # (28, 28, 1) -> (28, 28)
    plt.show()
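The same fix can also be sketched without TensorFlow. This is a minimal sketch, assuming the image is a NumPy array like the x_test[1] in the question; it swaps in NumPy's np.squeeze for tf.squeeze, and the function name show_mnist_digit is hypothetical. The matplotlib imshow documentation lists (M, N), (M, N, 3) and (M, N, 4) as the supported array shapes, so the sketch drops only a trailing channel axis of size 1.

```python
import numpy as np
import matplotlib.pyplot as plt


def show_mnist_digit(img):
    """Display one MNIST digit that may carry a trailing channel axis.

    Hypothetical helper: imshow documents (M, N), (M, N, 3) and (M, N, 4)
    as supported shapes, so a (28, 28, 1) array loses its last axis first.
    """
    if img.ndim == 3 and img.shape[-1] == 1:
        img2d = np.squeeze(img, axis=-1)  # (28, 28, 1) -> (28, 28)
    else:
        img2d = img  # already 2-D (or some other shape imshow will judge)
    plt.imshow(img2d, cmap="gray")  # gray colormap: MNIST is single-channel
    plt.axis("off")
    plt.show()
    return img2d
```

Usage: show_mnist_digit(x_test[1]). Restricting the squeeze to a trailing size-1 axis, instead of removing every size-1 axis as tf.squeeze does by default, means a wrongly shaped input such as a whole batch is passed through unchanged and still fails in imshow rather than being silently reshaped.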
    

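The question also mentions PNG. If the goal is an image file rather than an on-screen window, one option, sketched here under the assumption that the input is the same (28, 28, 1) NumPy array, is to squeeze it and write it with plt.imsave, which infers the file format from the .png extension. The helper save_digit_png, the synthetic demo_img, and the temp-file name are hypothetical stand-ins, not part of the original answer.

```python
import os
import tempfile

import numpy as np
import matplotlib.pyplot as plt


def save_digit_png(img, path):
    """Hypothetical helper: write one grayscale digit to a PNG file."""
    img2d = np.squeeze(img)  # drop size-1 axes: (28, 28, 1) -> (28, 28)
    if img2d.ndim != 2:
        raise ValueError(f"expected a single 2-D image, got shape {img2d.shape}")
    # imsave infers the format from the ".png" extension; cmap="gray" keeps it grayscale
    plt.imsave(path, img2d, cmap="gray")
    return path


# Synthetic stand-in for x_test[1]: a (28, 28, 1) uint8 gradient.
# In real use, pass the MNIST sample instead.
demo_img = (np.arange(28 * 28) % 256).astype(np.uint8).reshape(28, 28, 1)
out_path = save_digit_png(demo_img, os.path.join(tempfile.gettempdir(), "mnist_digit.png"))
```

The written file keeps the 28 x 28 pixel size of the digit, so it can be enlarged in any image viewer without matplotlib's axes or padding.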