
How to read and display MNIST dataset?


The code below opens the MNIST dataset from a CSV file:

import numpy as np
import csv
import matplotlib.pyplot as plt

with open('C:/Z_Uni/Individual_Project/Python_Projects/NeuralNet/MNIST_Dataset/mnist_train.csv/mnist_train.csv', 'r') as csv_file:
    for data in csv.reader(csv_file):
        # The first column is the label
        label = data[0]

        # The rest of columns are pixels
        pixels = data[1:]

        # Make those columns into an array of 8-bit pixels
        # The array is 1-D with length 784
        # The pixel intensity values are integers from 0 to 255
        pixels = np.array(pixels, dtype='uint8')

        print(pixels.shape)
        # Reshape the array into 28 x 28 array (2-dimensional array)
        pixels = pixels.reshape((28, 28))
        print(pixels.shape)
        # Plot
        plt.title('Label is {label}'.format(label=label))
        plt.imshow(pixels, cmap='gray')
        plt.show()

        break # This stops the loop, I just want to see one

I got the code above from someone else and cannot get it to display the MNIST digits.

I get the error:

Traceback (most recent call last):
  File "C:\Z_Uni\Individual_Project\Python_Projects\NeuralNet\Test_View_Mnist.py", line 16, in <module>
    pixels = np.array(pixels, dtype='uint8')
ValueError: invalid literal for int() with base 10: '1x1'

When I remove dtype='uint8' I get the error:

Traceback (most recent call last):
  File "C:\Z_Uni\Individual_Project\Python_Projects\NeuralNet\Test_View_Mnist.py", line 24, in <module>
    plt.imshow(pixels, cmap='gray')
  File "C:\Z_Uni\Individual_Project\Python_Projects\NeuralNet\source\lib\site-packages\matplotlib\_api\deprecation.py", line 456, in wrapper
    return func(*args, **kwargs)
  File "C:\Z_Uni\Individual_Project\Python_Projects\NeuralNet\source\lib\site-packages\matplotlib\pyplot.py", line 2640, in imshow
    __ret = gca().imshow(
  File "C:\Z_Uni\Individual_Project\Python_Projects\NeuralNet\source\lib\site-packages\matplotlib\_api\deprecation.py", line 456, in wrapper
    return func(*args, **kwargs)
  File "C:\Z_Uni\Individual_Project\Python_Projects\NeuralNet\source\lib\site-packages\matplotlib\__init__.py", line 1412, in inner
    return func(ax, *map(sanitize_sequence, args), **kwargs)
  File "C:\Z_Uni\Individual_Project\Python_Projects\NeuralNet\source\lib\site-packages\matplotlib\axes\_axes.py", line 5488, in imshow
    im.set_data(X)
  File "C:\Z_Uni\Individual_Project\Python_Projects\NeuralNet\source\lib\site-packages\matplotlib\image.py", line 706, in set_data
    raise TypeError("Image data of dtype {} cannot be converted to "
TypeError: Image data of dtype <U5 cannot be converted to float

Process finished with exit code 1

Could someone explain why this error is happening and how to fix it? Thanks.


Solution

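  • There are two problems here. (1) The first row of the CSV is a header of column names ('1x1', '1x2', and so on), not pixel data, so it has to be skipped; that row is what triggers the ValueError. Removing dtype only hides the problem: the header is then stored as an array of strings (dtype <U5), which imshow cannot convert to float. (2) The values need to be read as integers; the code below uses the int64 data type. next(csvreader) skips the first row.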

    import numpy as np
    import csv
    import matplotlib.pyplot as plt
    
    with open('./mnist_test.csv', 'r') as csv_file:
        csvreader = csv.reader(csv_file)
        next(csvreader)
        for data in csvreader:
            
            # The first column is the label
            label = data[0]
    
            # The rest of columns are pixels
            pixels = data[1:]
    
            # Convert those columns into an array of 64-bit integers
            # The array is 1-D with length 784
            # The pixel intensity values are integers from 0 to 255
            pixels = np.array(pixels, dtype='int64')
            print(pixels.shape)
            # Reshape the array into 28 x 28 array (2-dimensional array)
            pixels = pixels.reshape((28, 28))
            print(pixels.shape)
            # Plot
            plt.title('Label is {label}'.format(label=label))
            plt.imshow(pixels, cmap='gray')
            plt.show()

            break  # Stop after the first image, as in the question
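
To see both errors and the fix without the real file, here is a minimal sketch that runs on a tiny in-memory CSV. The helper name read_mnist_rows, the has_header flag and the synthetic rows are made up for illustration, and the 'label' name for the first header column is an assumption about the file layout (only '1x1', '1x2', ... appear in the question).

```python
import csv
import io

import numpy as np

IMAGE_SIDE = 28  # MNIST images are 28 x 28 pixels


def read_mnist_rows(csv_file, has_header=True):
    """Yield (label, image) pairs from an open MNIST-style CSV file.

    Hypothetical helper: each data row is assumed to hold a label
    followed by 784 pixel values.
    """
    reader = csv.reader(csv_file)
    if has_header:
        next(reader, None)  # Skip the column-name row ('1x1', '1x2', ...)
    for row in reader:
        label = int(row[0])
        # csv.reader returns strings, so convert each field to int first
        pixels = np.array([int(value) for value in row[1:]], dtype=np.uint8)
        yield label, pixels.reshape(IMAGE_SIDE, IMAGE_SIDE)


# Synthetic stand-in for the real file: a header row plus two data rows.
header = ['label'] + [f'{r}x{c}' for r in range(1, 29) for c in range(1, 29)]
row_a = ['7'] + ['0'] * 784
row_b = ['3'] + [str(i % 256) for i in range(784)]
text = '\n'.join(','.join(row) for row in (header, row_a, row_b))

# Problem 1: the header fields are not numbers, so int() rejects them.
try:
    int(header[1])
    header_error = None
except ValueError as exc:
    header_error = str(exc)

# Problem 2: without a dtype, NumPy keeps the header as a string array,
# which is the kind of data imshow refused to convert to float.
header_kind = np.array(header[1:]).dtype.kind  # 'U' means Unicode string

# Fix: skip the header, then convert the remaining rows to integers.
samples = list(read_mnist_rows(io.StringIO(text)))
for label, image in samples:
    print(label, image.shape, image.dtype)
```

Each image yielded this way is a 28 x 28 integer array, so it can be passed to plt.imshow(image, cmap='gray') exactly as in the code above. With the real file, pass the object returned by open(...) instead of io.StringIO(text).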