python-3.x keras filter mnist

How can I remove a specific digit from the MNIST dataset imported from keras?


I'm trying to remove a specific digit (such as 0 or 4) from the MNIST dataset provided by Keras.

from keras.datasets import mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()

x_train.drop([0], axis = 1)
y_train.drop([0], axis = 1)

x_train_0 = x_train[0]
y_train_0 = y_train[0]

It raises an error: AttributeError: 'numpy.ndarray' object has no attribute 'drop'

What should I do?

Also, if I want to extract only the data for digit 0, can I simply use x_train[0]?

Thank you!!!


Solution

  • Let's look at the Keras MNIST data format first.

    >>> from keras.datasets import mnist
    >>> (x_train, y_train), (x_test, y_test) = mnist.load_data()
    >>> x_train.shape
    (60000, 28, 28)
    >>> y_train.shape
    (60000,)
    

    So the x_... variables hold the images, and the y_... variables hold the labels. Both are NumPy arrays. What order are the data in?

    >>> y_train[:20]
    array([5, 0, 4, 1, 9, 2, 1, 3, 1, 4, 3, 5, 3, 6, 1, 7, 2, 8, 6, 9],
          dtype=uint8)
    

    They're in a random order. That is convenient if you want a small fraction of the data: a contiguous slice will easily include every digit. But it makes your task harder. You need the indices that correspond to the digit in question, and then you need to use those indices to select both the images and the labels you want. (This also answers your second question: x_train[0] is just the first image, whose label is 5, not the set of all zeros.)
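    Here is a minimal sketch of both points, run on a small synthetic stand-in rather than the real download: the labels are the first 20 training labels shown above, and the "images" are blank 28x28 arrays (an assumption made only so the example runs offline). np.bincount confirms that even this short slice contains all ten digits, and np.flatnonzero returns the indices of the samples labelled 0, which then select the matching images and labels together.

```python
import numpy as np

# Synthetic stand-in for MNIST: the first 20 training labels shown above,
# paired with 20 blank 28x28 "images" (real pixels are not needed here).
y_demo = np.array([5, 0, 4, 1, 9, 2, 1, 3, 1, 4, 3, 5, 3, 6, 1, 7, 2, 8, 6, 9],
                  dtype=np.uint8)
x_demo = np.zeros((20, 28, 28), dtype=np.uint8)
x_demo[:, 0, 0] = np.arange(20)  # tag each image with its position so selections are traceable

# How often each digit 0..9 occurs: every count is non-zero, even in 20 samples.
counts = np.bincount(y_demo, minlength=10)
print(counts)        # [1 4 2 3 2 2 2 1 1 2]

# Indices of every sample labelled 0, then the matching images and labels.
idx_0 = np.flatnonzero(y_demo == 0)
x_0, y_0 = x_demo[idx_0], y_demo[idx_0]
print(idx_0)         # [1]
print(x_0.shape)     # (1, 28, 28)
```

    With the real data the same pattern would be idx_0 = np.flatnonzero(y_train == 0), followed by x_train[idx_0] and y_train[idx_0].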

    Look at the NumPy array method nonzero(), and make sure you understand how NumPy uses arrays of Boolean values to select elements from arrays with compatible shapes. This short function does what you need:

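    def remove(digit, x, y):
        # Indices of every sample whose label is NOT `digit`
        idx = (y != digit).nonzero()
        # The same indices select the matching images and labels
        return x[idx], y[idx]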
    

    And here's an example of how to call it:

    x_no3, y_no3 = remove(3, x_train, y_train)
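    The function above removes one digit per call. Since the question mentions dropping digits such as 0 and 4, here is a hedged sketch of a variant that removes several at once. It builds the Boolean mask with np.isin and indexes with the mask directly; calling nonzero() first is optional, because NumPy accepts a Boolean array of matching length as an index. The helper name remove_digits and the toy arrays are illustrative only, not part of Keras.

```python
import numpy as np

def remove_digits(digits, x, y):
    """Hypothetical helper: drop every sample whose label is in `digits`."""
    keep = ~np.isin(y, digits)   # True where the label is NOT one of `digits`
    return x[keep], y[keep]      # the same mask filters images and labels in step

# Toy stand-in for (x_train, y_train): ten labels and ten blank 28x28 images.
y_demo = np.array([5, 0, 4, 1, 9, 2, 1, 3, 1, 4], dtype=np.uint8)
x_demo = np.zeros((10, 28, 28), dtype=np.uint8)

x_kept, y_kept = remove_digits([0, 4], x_demo, y_demo)
print(y_kept)        # [5 1 9 2 1 3 1]
print(x_kept.shape)  # (7, 28, 28)

# Indexing with a mask or with its nonzero() indices selects the same elements.
mask = y_demo != 3
assert np.array_equal(y_demo[mask], y_demo[mask.nonzero()])
```

    On the real data you would call, for example, x_train_kept, y_train_kept = remove_digits([0, 4], x_train, y_train), and do the same for x_test and y_test.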