
Sample each image from dataset N times in single batch


I'm currently working on a representation-learning task (deep embeddings). The dataset I use has only one example image per object. I also use augmentation.

During training, each batch must contain N different augmented versions of a single image from the dataset (dataset[index] always returns a new random transformation).

Is there some standard solution or library with a DataLoader for this purpose that will work with torch.utils.data.distributed.DistributedSampler? If not, will any DataLoader inherited from torch.utils.data.DataLoader (and calling super().__init__(...)) work in distributed training?


Solution

  • As far as I know, this is not a standard way of doing things. Even if you have only one sample per object, one would still sample images from different objects per batch, and in different epochs the sampled images would be transformed differently.

    That said, if you truly want to do what you are doing, why not simply write a wrapper around your dataset?

import torch
from torch.utils.data import Dataset

class Wrapper(Dataset):
    def __init__(self, dataset, n=16):
        self.dataset, self.n = dataset, n  # wrapped dataset and number of augmented copies per item
    def __len__(self):
        return len(self.dataset)
    def __getitem__(self, index):
        # each access to the wrapped dataset returns a fresh random augmentation
        sample = [self.dataset[index] for _ in range(self.n)]
        return torch.stack(sample, dim=0)

    Then each of your batches will have shape B x N x C x H x W, where B is the batch size and N is the number of repetitions. You can reshape the batch to (B*N) x C x H x W after you get it from the dataloader, as sketched below.
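
    For the distributed part of the question, here is a minimal usage sketch, assuming the Wrapper above, an original dataset called `dataset` (a placeholder name), and an already-initialized process group. DistributedSampler only needs a map-style dataset with a length, so the wrapper works with it like any other dataset:

import torch
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

wrapped = Wrapper(dataset, n=16)                      # `dataset`: your one-image-per-object dataset
sampler = DistributedSampler(wrapped)                 # requires torch.distributed to be initialized
loader = DataLoader(wrapped, batch_size=8, sampler=sampler)

for batch in loader:                                  # batch shape: B x N x C x H x W
    b, n, c, h, w = batch.shape
    batch = batch.view(b * n, c, h, w)                # flatten to (B*N) x C x H x W for the model

    Remember to call sampler.set_epoch(epoch) at the start of each epoch so the shuffling differs across epochs.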