python machine-learning pytorch

Get a smaller MNIST dataset in pytorch


This is how I load the dataset, but it is too big: there are about 60k images. I would like to limit it to 1/10 of that for training. Is there a built-in way to do that?

from torch.utils.data import DataLoader
from torchvision import datasets
import torchvision.transforms as transforms
train_data = datasets.MNIST(
    root='data',
    train=True,
    transform=transforms.Compose(
        [transforms.ToTensor()]
    ),
    download=True
)

print(train_data)

print(train_data.data.size())
print(train_data.targets.size())



loaders = {
    'train': DataLoader(train_data,
                        batch_size=100),
}

Solution

  • aretor's answer doesn't shuffle the data, and Prajot's answer wastefully creates a test set. Here's a better solution IMO using SubsetRandomSampler:

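    import torch
    from torch.utils.data import DataLoader, SubsetRandomSampler
    
    K = 6000  # subset size; 6000 is 1/10 of the 60k MNIST training images
    batch_size = 100  # same batch size as in the question
    # Draw K distinct random indices once, then sample batches only from those
    subsample_train_indices = torch.randperm(len(train_data))[:K]
    train_loader = DataLoader(train_data, batch_size=batch_size, sampler=SubsetRandomSampler(subsample_train_indices))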
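
To check that this really limits an epoch to K samples without downloading MNIST, here is a minimal sketch on a synthetic stand-in dataset. The names `full_data`, `small_data`, and `seen_labels` are made up for illustration, and each sample's label is simply its own index so the drawn samples can be identified. It also shows `torch.utils.data.Subset` as an alternative that should behave the same way when combined with `shuffle=True`.

```python
import torch
from torch.utils.data import DataLoader, Subset, SubsetRandomSampler, TensorDataset

# Hypothetical stand-in for MNIST: 60,000 samples with one dummy feature each.
# The "label" of every sample is its own index, so we can see what was drawn.
N, K, batch_size = 60_000, 6_000, 100
full_data = TensorDataset(torch.zeros(N, 1), torch.arange(N))

# Draw K distinct indices once (seeded here only to make the sketch repeatable).
generator = torch.Generator().manual_seed(0)
indices = torch.randperm(N, generator=generator)[:K].tolist()

# Option 1: keep the full dataset and restrict the sampler to the chosen indices.
sampler_loader = DataLoader(full_data, batch_size=batch_size,
                            sampler=SubsetRandomSampler(indices))

# Option 2: wrap the dataset in Subset and let the DataLoader shuffle it.
small_data = Subset(full_data, indices)
subset_loader = DataLoader(small_data, batch_size=batch_size, shuffle=True)


def seen_labels(loader):
    """Collect the labels (here: original indices) yielded during one epoch."""
    return torch.cat([labels for _, labels in loader])


sampler_seen = seen_labels(sampler_loader)
subset_seen = seen_labels(subset_loader)
print(len(sampler_loader), sampler_seen.numel(), subset_seen.numel())
```

Both loaders visit the same K samples every epoch, only in a different order each time. If you want a fresh random tenth per epoch, you would have to redraw the indices and rebuild the sampler yourself.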