
How to separate the Keras MNIST dataset into 5 groups: (0, 1), (2, 3), (4, 5), (6, 7), (8, 9)


I am new to machine learning. I am trying to train models on the Keras MNIST dataset, but I want to train a separate model on each of the 5 digit groups. Can someone please advise how to split the MNIST dataset into these groups?

I have searched Google for quite some time, but couldn't figure out how to do this.

Many thanks in advance!


Solution

  • How about using a for loop:

    from keras.datasets import mnist
    import numpy as np
    
    # Load MNIST, then merge the official train and test splits into one pool.
    # Note: this discards the original train/test separation.
    (train_images, train_labels), (test_images, test_labels) = mnist.load_data()
    images = np.concatenate((train_images, test_images), axis=0)
    labels = np.concatenate((train_labels, test_labels), axis=0)
    
    groups = [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]]
    images_and_labels_by_group = []
    for group in groups:
        # Indices of every sample whose label is one of the digits in this group
        indices = np.where(np.isin(labels, group))[0]
        group_images = images[indices]
        group_labels = labels[indices]
        images_and_labels_by_group.append((group_images, group_labels))
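If you want to keep the train and test sets separate (usually what you want when evaluating each group's model), the same loop can be wrapped in a function and applied to each split. This is a minimal sketch: `split_by_groups`, `GROUPS`, and the synthetic arrays are hypothetical names used for illustration, and the fake data only stands in for the arrays returned by `mnist.load_data()` so the snippet runs without downloading anything. It uses a boolean mask from `np.isin` directly instead of `np.where(...)[0]`; both select the same samples in the same order. It also shows that, because each group is two consecutive digits, `labels // 2` gives the group index (0 to 4).

```python
import numpy as np

def split_by_groups(images, labels, groups):
    """Return a list of (images, labels) pairs, one per group of digit classes.

    Hypothetical helper: same logic as the loop above, packaged as a function.
    """
    out = []
    for group in groups:
        mask = np.isin(labels, group)  # True where the label is in this group
        out.append((images[mask], labels[mask]))
    return out

GROUPS = [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]]

# Synthetic stand-in for MNIST (assumption, for a runnable demo):
# 20 fake 28x28 "images", each filled with its own index, and labels 0..9 twice.
fake_images = np.broadcast_to(np.arange(20)[:, None, None], (20, 28, 28)).copy()
fake_labels = np.tile(np.arange(10), 2)

train_groups = split_by_groups(fake_images, fake_labels, GROUPS)
for group, (x, y) in zip(GROUPS, train_groups):
    print(group, x.shape, np.unique(y))

# Shortcut for these particular groups: labels // 2 is the group index.
group_index = fake_labels // 2
same = np.array_equal(fake_images[group_index == 1], train_groups[1][0])
print("labels // 2 shortcut matches np.isin:", same)

# With real data you would call it once per split, for example:
#   (train_images, train_labels), (test_images, test_labels) = mnist.load_data()
#   train_groups = split_by_groups(train_images, train_labels, GROUPS)
#   test_groups = split_by_groups(test_images, test_labels, GROUPS)
```

If each group is meant to be trained as a binary classifier, `y - group[0]` maps the two labels of a consecutive pair such as (4, 5) to 0 and 1.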