python pytorch dataset mnist

How to extract a specific digit from the MNIST dataset with dataloader?


I am feeding the MNIST dataset to my neural network for training as follows:

indices = torch.arange(60000)
dataset = datasets.MNIST(root="dataset/", transform=transforms, download=True)
datasetsmall = data_utils.Subset(dataset, indices)
loader = DataLoader(datasetsmall, batch_size=batch_size, shuffle=True)

However, since training is taking a very long time to complete, I have decided to train the model on only one specific digit from the MNIST dataset, for example the digit 4. How can I extract just the digit 4 and feed it to my neural network in the same way? My training loop looks like this:

for batch_idx, (real, _) in enumerate(loader):

Now I want only the digit 4 in the loader. How should I proceed in that case?


Solution

  • Does this code solve your problem?

    from torch.utils.data import DataLoader
    from torchvision import datasets
    from torchvision.transforms import ToTensor

    cls = 4  # the digit (class) to keep
    batch_size = 32

    dataset = datasets.MNIST(root="dataset/", download=True, transform=ToTensor())
    # Keep only the (image, label) pairs whose label is cls. This walks the
    # whole dataset once, so every image is loaded and transformed up front.
    dataset = list(filter(lambda sample: sample[1] == cls, dataset))
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # Count the batches the loader yields
    s = 0
    for i in loader:
        s += 1

    print(f"We've got {s} batches with batch_size {batch_size} only for class {cls}")

    # print(i)  # uncomment to examine the last batch yourself
    

    Result:

    We've got 183 batches with batch_size 32 only for class 4
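
If filtering the whole dataset into a list turns out to be slow, a possible alternative is to select indices by label and wrap the dataset in a `Subset`, much like the question already does. The sketch below is only an illustration: `make_class_loader` is a hypothetical helper name, and it assumes the dataset exposes a `targets` attribute with one integer label per sample, which torchvision's MNIST does.

```python
import torch
from torch.utils.data import DataLoader, Subset


def make_class_loader(dataset, cls, batch_size, shuffle=True):
    """Return a DataLoader over only the samples of `dataset` labelled `cls`.

    Hypothetical helper. Assumes `dataset.targets` holds one integer label
    per sample (true for torchvision's MNIST).
    """
    targets = torch.as_tensor(dataset.targets)
    # Positions whose label equals cls; no image is loaded at this point.
    indices = torch.nonzero(targets == cls).flatten().tolist()
    # Subset defers to the original dataset, so its transform still applies.
    return DataLoader(Subset(dataset, indices), batch_size=batch_size, shuffle=shuffle)
```

With the unfiltered `dataset` returned by `datasets.MNIST(...)`, calling `loader = make_class_loader(dataset, cls=4, batch_size=32)` should give a loader that works with the same `for batch_idx, (real, _) in enumerate(loader):` loop and yields the same number of batches as the filtered list.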