tensorflow python-3.5 mnist

Is there any way to only import the MNIST images with 0's and 1's?


I am just starting out with TensorFlow and I want to test something using only the 0's and 1's from the MNIST images. Is there a way to import only these images?


Solution

  • Assuming you are using

    from tensorflow.examples.tutorials.mnist import input_data

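    No, there is no function or argument in that module for loading only certain digits, and the same is true of tf.keras.datasets.mnist.load_data(). What you can do is load all of the data and then keep only the ones and zeros with a boolean mask: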

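    import numpy as np
    import tensorflow as tf

    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
    train_mask = np.isin(y_train, [0, 1])  # True where the label is 0 or 1
    x_train, y_train = x_train[train_mask], y_train[train_mask]
    test_mask = np.isin(y_test, [0, 1])
    x_test, y_test = x_test[test_mask], y_test[test_mask]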
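If you stay with the older input_data module, read_data_sets(..., one_hot=True) returns each label as a one-hot row of length 10, so running np.isin on the labels directly would not select digits. Below is a minimal sketch of a helper that accepts either label format by recovering the class IDs with argmax first. The name filter_digits is hypothetical, and the arrays are small synthetic stand-ins rather than real MNIST data.

```python
import numpy as np


def filter_digits(images, labels, digits=(0, 1)):
    """Keep only the samples whose class is in `digits` (hypothetical helper).

    `labels` may be integer class IDs with shape (n,), as returned by
    tf.keras.datasets.mnist.load_data(), or one-hot rows with shape (n, 10),
    as returned by input_data.read_data_sets(..., one_hot=True).
    The labels are returned in their original format.
    """
    labels = np.asarray(labels)
    # Recover integer class IDs from one-hot rows if needed
    class_ids = labels.argmax(axis=1) if labels.ndim == 2 else labels
    mask = np.isin(class_ids, digits)  # True where the class is wanted
    return images[mask], labels[mask]


# Synthetic stand-in for MNIST: 6 "images" of 28x28 pixels
images = np.arange(6 * 28 * 28).reshape(6, 28, 28)
int_labels = np.array([0, 1, 2, 1, 7, 0])
one_hot_labels = np.eye(10)[int_labels]  # same labels, one-hot encoded

x01, y01 = filter_digits(images, int_labels)
print(y01)  # [0 1 1 0]

x01_oh, y01_oh = filter_digits(images, one_hot_labels)
print(y01_oh.argmax(axis=1))  # [0 1 1 0]
```

With the legacy module you would call it on something like mnist.train.images and mnist.train.labels, and likewise on the test split. Passing a different digits tuple, for example (3, 8), selects any other subset.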