python · pytorch · dataloader

How to make the trainloader use a specific amount of images?


Assume I am using the following calls:

trainset = torchvision.datasets.ImageFolder(root="imgs/", transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=1)

As far as I can tell, this defines the trainset as consisting of all the images in the folder "imgs/", with labels determined by each image's sub-folder.

My question is: is there a direct/easy way to define the trainset as a sub-sample of the images in this folder? For example, to define trainset as a random sample of 10 images from every sub-folder?


Solution

  • You can wrap the class DatasetFolder (or ImageFolder) in another class to limit the dataset:

    from torch.utils import data

    class LimitDataset(data.Dataset):
        """Expose only the first n samples of the wrapped dataset."""

        def __init__(self, dataset, n):
            self.dataset = dataset
            self.n = min(n, len(dataset))  # never report more samples than the dataset has

        def __len__(self):
            return self.n

        def __getitem__(self, i):
            return self.dataset[i]
    

    You can also define some mapping between the index in LimitDataset and the index in the original dataset to define more complex behavior (such as random subsets).
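
    For instance, here is a minimal sketch of that idea, drawing a random sample of 10 images per sub-folder. It relies on the .targets list that ImageFolder exposes and on torch.utils.data.Subset; sample_per_class is just an illustrative helper name, not a torchvision API, and trainset/transform are assumed to come from the question:

    import random
    from collections import defaultdict
    from torch.utils.data import Subset

    def sample_per_class(dataset, k=10, seed=0):
        # Group sample indices by their class label.
        by_class = defaultdict(list)
        for idx, label in enumerate(dataset.targets):
            by_class[label].append(idx)
        # Keep k random indices per class (or all of them if a class has fewer than k).
        rng = random.Random(seed)
        chosen = [i for indices in by_class.values()
                    for i in rng.sample(indices, min(k, len(indices)))]
        return Subset(dataset, chosen)

    small_trainset = sample_per_class(trainset, k=10)  # trainset from the question
    trainloader = torch.utils.data.DataLoader(small_trainset, batch_size=4,
                                              shuffle=True, num_workers=1)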

    If you want to limit the batches per epoch instead of the dataset size:

    from itertools import islice
    for data in islice(dataloader, 0, batches_per_epoch):
        ...
    

    Note that with shuffle=True the dataset size stays the same, but the data each epoch sees will be limited (and will differ from epoch to epoch). Without shuffling, this effectively limits the dataset to the same first batches_per_epoch batches every epoch.
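
    For completeness, a usage sketch combining LimitDataset with the calls from the question (the limit of 50 images is an arbitrary example):

    trainset = torchvision.datasets.ImageFolder(root="imgs/", transform=transform)
    limited = LimitDataset(trainset, n=50)   # only the first 50 images are visible
    trainloader = torch.utils.data.DataLoader(limited, batch_size=4,
                                              shuffle=True, num_workers=1)

    for images, labels in trainloader:
        ...  # at most 13 batches per epoch (ceil(50 / 4))

    Keep in mind that shuffle=True here only shuffles within those first 50 images; to get a random subset from every sub-folder, use the index-mapping / Subset approach sketched above.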