
Will this code work to recognise MNIST digits? (k-NN method)


I'm not sure whether the following code works, because it has been stuck at "Compute predictions" for a long time. If it won't work, what should I change?

import struct
import numpy as np
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score

# k-NN classifier with scikit-learn's defaults (k = 5 neighbours)
clf = KNeighborsClassifier()

def load_data():
    # Each IDX file starts with a big-endian header (magic number, item count
    # and, for image files, the row and column sizes); raw uint8 data follows.
    with open('train-labels-idx1-ubyte', 'rb') as labels:
        magic, n = struct.unpack('>II', labels.read(8))
        train_labels = np.fromfile(labels, dtype=np.uint8)
    with open('train-images-idx3-ubyte', 'rb') as imgs:
        magic, num, nrows, ncols = struct.unpack('>IIII', imgs.read(16))
        # Flatten each 28x28 image into a 784-element feature vector
        train_images = np.fromfile(imgs, dtype=np.uint8).reshape(num, nrows * ncols)
    with open('t10k-labels-idx1-ubyte', 'rb') as labels:
        magic, n = struct.unpack('>II', labels.read(8))
        test_labels = np.fromfile(labels, dtype=np.uint8)
    with open('t10k-images-idx3-ubyte', 'rb') as imgs:
        magic, num, nrows, ncols = struct.unpack('>IIII', imgs.read(16))
        test_images = np.fromfile(imgs, dtype=np.uint8).reshape(num, nrows * ncols)
    return train_images, train_labels, test_images, test_labels


def knn(train_x, train_y, test_x, test_y):
    # fit() is comparatively cheap for k-NN; most of the work happens in predict()
    clf.fit(train_x, train_y)
    print("Compute predictions")
    # Each test image is compared with the training images to find its
    # nearest neighbours, so this step dominates the running time
    predicted = clf.predict(test_x)
    print("Accuracy: ", accuracy_score(test_y, predicted))

train_x, train_y, test_x, test_y = load_data()
knn(train_x, train_y, test_x, test_y)

Solution

  • it has been stuck on "Computing Prediction" for a long time

    I recommend testing on a very small subset of the data first to check that everything runs correctly, before running on the whole dataset. That way you can confirm that the code itself is sound.

    Once the code is tested, you can safely move on to training with the whole dataset.

    This also makes it easy to tell whether the code is slow because of a bug or simply because of the amount of data. With k-NN this matters most for the prediction step, since every test sample has to be compared with the training samples. For example, the code may be correct, yet you may find that even 10 samples take longer than you are willing (or able) to wait, and you can adjust accordingly. Without such a test, what you're dealing with is too much of a black box.

    That said, if the code is fine but simply takes too long, I, like Soumya, suggest trying Google Colab. You get good hardware, sessions of up to 12 hours, and your own PC stays free for testing other code in the meantime!
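A minimal sketch of the small-subset test described above, assuming scikit-learn and NumPy are installed. The helper name `timed_knn` and the subset sizes are illustrative choices, not part of the original code. Random arrays shaped like MNIST (784 uint8 pixels, labels 0 to 9) stand in for the real data so the snippet runs on its own, which means the printed accuracies are meaningless; for a real run, pass in the arrays returned by `load_data()` instead.

```python
import time

import numpy as np
from sklearn.metrics import accuracy_score
from sklearn.neighbors import KNeighborsClassifier


def timed_knn(train_x, train_y, test_x, test_y, n_train, n_test):
    """Fit k-NN on the first n_train samples and predict the first n_test.

    Hypothetical helper for illustration. Returns (accuracy, seconds in predict).
    """
    clf = KNeighborsClassifier()
    clf.fit(train_x[:n_train], train_y[:n_train])
    start = time.perf_counter()
    predicted = clf.predict(test_x[:n_test])
    elapsed = time.perf_counter() - start
    return accuracy_score(test_y[:n_test], predicted), elapsed


# Synthetic stand-in with MNIST's shape (assumption: replace with load_data()).
rng = np.random.default_rng(0)
train_x = rng.integers(0, 256, size=(2000, 784), dtype=np.uint8)
train_y = rng.integers(0, 10, size=2000, dtype=np.uint8)
test_x = rng.integers(0, 256, size=(500, 784), dtype=np.uint8)
test_y = rng.integers(0, 10, size=500, dtype=np.uint8)

# Grow the subset step by step and watch how the prediction time scales.
results = []
for n_train, n_test in [(100, 10), (1000, 100), (2000, 500)]:
    acc, secs = timed_knn(train_x, train_y, test_x, test_y, n_train, n_test)
    results.append((n_train, n_test, acc, secs))
    print(f"train={n_train:5d} test={n_test:4d} acc={acc:.3f} predict={secs:.3f}s")
```

If each run finishes and the timings grow roughly in proportion to `n_train * n_test`, the code is most likely fine and the full 60,000 x 10,000 run is simply expensive; that rough scaling also gives a ballpark estimate of how long the full run might take.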