
Matplotlib hist method


I'm trying to plot a histogram of the labels in the MNIST dataset:

from sklearn.datasets import fetch_openml
import matplotlib.pyplot as plt

numberMNIST = fetch_openml('mnist_784', return_X_y=False)

dataset = numberMNIST.data
labels = numberMNIST.target

X_train, X_test, Y_train, Y_test = dataset[:60000], dataset[60000:], labels[:60000], labels[60000:]
Y_train_is4 = (Y_train == '4')
Y_test_is4 = (Y_test == '4')

plt.hist(Y_train)
plt.xlabel("Label")
plt.ylabel("Quantity")
plt.title("Labels in MNIST 784 dataset")
plt.show()

However, the labels on the x-axis are not in order:

[Image: the resulting histogram, with the labels out of order on the x-axis]

How can I fix this so that the x-axis shows the labels in order: 0, 1, 2, ..., 9?
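For reference, here is a minimal sketch that reproduces the same ordering without downloading MNIST. It assumes a recent matplotlib (which plots a list of strings as categorical values) and uses a short, made-up list of string labels in place of `Y_train`:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the script runs without a display
import matplotlib.pyplot as plt

# Made-up string labels standing in for Y_train (not taken from MNIST)
labels = ["5", "0", "4", "1", "9", "2", "1", "3", "1", "4",
          "3", "5", "3", "6", "1", "7", "2", "8", "6", "9"]

fig, ax = plt.subplots()
ax.hist(labels)        # strings are plotted as categorical values
fig.canvas.draw()      # force the tick labels to be computed

tick_text = [t.get_text() for t in ax.get_xticklabels()]
print(tick_text)
```

The tick labels follow the order in which each string first appears in the data, not sorted order.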


Solution

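  • The `hist` function is meant for numeric data. In your case the labels are strings (NumPy stores them with dtype `object`), so matplotlib treats them as categories and places them on the x-axis in the order they first appear in the data, not in sorted order. I would use `bar` instead: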

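    import numpy as np
    import matplotlib.pyplot as plt

    # Count each label; np.unique returns the labels sorted ('0' < '1' < ... < '9')
    unique, counts = np.unique(Y_train, return_counts=True)
    plt.bar(unique, counts)
    plt.xticks(unique)  # one tick per label
    plt.xlabel("Label")
    plt.ylabel("Quantity")
    plt.title("Labels in MNIST 784 dataset")
    plt.show()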
    

    [Image: the resulting bar chart]
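If you would rather keep `hist`, one alternative is to convert the labels to integers first and give it one bin per digit, with bin edges at the half-integers so each bar is centred on its label. This is a sketch under that assumption; the short `Y_train` array below is made up for illustration, and with the real data `Y_train.astype(int)` should perform the same conversion.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; drop this line in a notebook
import matplotlib.pyplot as plt
import numpy as np

# Hypothetical stand-in for Y_train: digit labels stored as strings
Y_train = np.array(["5", "0", "4", "1", "9", "2", "1", "3", "1", "4"])

digits = Y_train.astype(int)   # "4" -> 4, so matplotlib sees numbers, not categories
edges = np.arange(11) - 0.5    # -0.5, 0.5, ..., 9.5: one bin centred on each digit
counts, bin_edges, patches = plt.hist(digits, bins=edges, rwidth=0.8)
plt.xticks(range(10))          # ticks at 0..9, in numeric order
plt.xlabel("Label")
plt.ylabel("Quantity")
plt.title("Labels in MNIST 784 dataset")
```

Because the data are now numeric, the x-axis is ordered 0 to 9 regardless of the order in which the labels occur.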