Search code examples
Tags: python, pytorch, dataset, huggingface-datasets

Why doesn't my PyTorch Dataset return lists as expected?


I am trying to use torch.utils.data.Dataset with a custom dataset. Each row of my dataset contains a list of 10 images (shortened to three below), like this:

| word | images | gold_image |
|:-----|:-------|:-----------|
|'andromeda'|['image.1.jpg','image.2.jpg','image.3.jpg']|[0,0,1]|

With batch_size=4, I expect the DataLoader to return batches like this:

('word_1', 'word_2', 'word_3', 'word_4'), ([image_1,image_2,image_3],[image_4,image_5,image_6],[image_7,image_8,image_9], [image_10,image_11,image_12]), ([0,0,1],[1,0,0],[0,1,0],[0,1,0])

But instead I am getting this:

('word_1', 'word_2', 'word_3', 'word_4'), [(image_1,image_2,image_3,image_4),(image_5,image_6,image_7,image_8), (image_9,image_10,image_11,image_12)], [(0,1,0,0),(1,0,0,0),(0,1,0,1)]

Here is my code:

class ImageTextDataset(Dataset):
    """Dataset yielding ``(context, images, labels)`` for one row of ``train_df``.

    Each row of ``train_df`` provides a ``description`` (returned as the
    context), an ``images`` list of file names, and a ``gold_image`` entry used
    to build one float label per image.

    NOTE: ``__getitem__`` returns plain Python lists. PyTorch's default
    collate function transposes sequence fields across a batch. With the
    default ``DataLoader``, a batch of B samples with N images each therefore
    comes back as N groups of B images. Pass a custom ``collate_fn`` to keep
    per-sample grouping.
    """

    def __init__(self, data_dir, train_df, tokenizer, feature_extractor, data_type, device,
                 text_augmentation=False, image_subdir="trial_images_v1"):
        """Store the columns of ``train_df`` and build the image transform.

        Args:
            data_dir: Root directory that contains ``image_subdir``.
            train_df: Table indexed by column name with 'images', 'word',
                'description' and 'gold_image' columns (presumably a pandas
                DataFrame; verify against the caller).
            tokenizer: Stored on the instance; not used by this class.
            feature_extractor: Stored on the instance; not used by this class.
            data_type: Split name. Only ``"train"`` is supported.
            device: Accepted for API compatibility; currently unused.
            text_augmentation: Accepted for API compatibility; currently unused.
            image_subdir: Sub-directory of ``data_dir`` holding the image files.

        Raises:
            ValueError: If ``data_type`` is not ``"train"``. Otherwise the
                instance would be left without its data attributes.
        """
        self.data_dir = data_dir
        self.image_subdir = image_subdir

        if data_type != "train":
            raise ValueError(f"unsupported data_type: {data_type!r} (only 'train' is supported)")

        # Original train set of the task; every image is resized to 512x512.
        self.tokenizer = tokenizer
        self.feature_extractor = feature_extractor
        self.transforms = transforms.Compose([
            transforms.Resize([512, 512]),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
        self.all_image_names = list(train_df['images'])
        self.keywords = list(train_df['word'])
        self.context = list(train_df['description'])
        self.gold_images = list(train_df['gold_image'])

    def __len__(self):
        """Return the number of rows (one sample per row)."""
        return len(self.context)

    def __getitem__(self, idx):
        """Load and transform every image listed in row ``idx``.

        Returns:
            ``(context, images, labels)``. ``images`` is a list with one
            transformed image per file name in the row. ``labels`` is a list of
            floats: 1.0 where the file name equals the row's gold image and
            0.0 otherwise.
        """
        context = self.context[idx]
        gold = self.gold_images[idx]
        image_dir = os.path.join(self.data_dir, self.image_subdir)

        images = []
        labels = []
        for name in self.all_image_names[idx]:
            path = os.path.join(image_dir, name)
            # The context manager closes the underlying file once the image is transformed.
            with Image.open(path) as pil_img:
                if pil_img.mode != "RGB":
                    pil_img = pil_img.convert("RGB")
                images.append(self.transforms(pil_img))
            # Compare the *file name*, not the loaded image or tensor, against the gold entry.
            # NOTE(review): assumes gold_image holds a file name; if it already holds a
            # one-hot list (as in the question's table), return that list instead.
            labels.append(1.0 if name == gold else 0.0)

        return (context, images, labels)

I can't figure out what the issue is. Can anyone help?

Thanks in advance.


Solution

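  • The DataLoader collates the output of your dataset into batches using the default collate function (implemented in torch/utils/data/_utils/collate.py). What you're observing is the expected behavior when a dataset returns sequence-type objects (like lists). For each sequence field, default_collate transposes the batch, grouping the i-th element of every sample together. With 4 samples of 3 images each, you therefore get 3 groups of 4 images, and likewise 3 groups of 4 labels.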

    If you want the dataloader to collate your data differently, then you can provide a custom collate function via the collate_fn argument of DataLoader.

    You can read more about collation and custom collate functions in the documentation.

    The following is a simple example of using a custom collate function that I believe accomplishes what you want, though you may need to change it a bit if it's not exactly what you need.

    import torch
    from torch.utils.data import DataLoader
    from torch.utils.data._utils.collate import default_collate
    
    
    class FakeDataset:
        """Tiny synthetic dataset used to demonstrate custom collation.

        Item ``index`` is ``(context, images, label)``:
        ``context`` is the string ``context_<index>``, ``images`` holds three
        float tensors of shape (2, 5, 5) filled with ``index``, and ``label``
        is a one-hot list of three floats whose hot slot is ``index % 3``.
        """

        def __getitem__(self, index):
            hot_slot = index % 3
            frames = [torch.full((2, 5, 5), index, dtype=torch.float) for _ in range(3)]
            one_hot = [1.0 if slot == hot_slot else 0.0 for slot in range(3)]
            return f'context_{index}', frames, one_hot

        def __len__(self):
            return 100
    
    
    def my_collate_fn(batch):
        """Collate ``(context, images, label)`` samples while keeping per-sample grouping.

        Contexts are collated across the batch. Each sample's image list and
        label list are collated on their own, so the result holds one stacked
        image tensor and one label tensor per sample instead of being
        transposed across the batch.
        """
        raw_contexts = [sample[0] for sample in batch]
        stacked_images = []
        label_tensors = []
        for sample in batch:
            stacked_images.append(default_collate(sample[1]))
        for sample in batch:
            label_tensors.append(default_collate(sample[2]))
        # default_collate on strings just hands the list back; kept for symmetry.
        return default_collate(raw_contexts), stacked_images, label_tensors
    
    
    # Wire the custom collate function into a DataLoader that groups 4 samples per batch.
    loader = DataLoader(dataset=FakeDataset(), batch_size=4, collate_fn=my_collate_fn)

    # Fetch only the first batch to inspect its structure.
    contexts, images, labels = next(iter(loader))

    print('contexts =', contexts)
    print('images (sizes) =', [stacked.shape for stacked in images])
    print('labels =', labels)
    

    which prints

    contexts = ['context_0', 'context_1', 'context_2', 'context_3']
    images (sizes) = [torch.Size([3, 2, 5, 5]), torch.Size([3, 2, 5, 5]), torch.Size([3, 2, 5, 5]), torch.Size([3, 2, 5, 5])]
    labels = [tensor([1., 0., 0.], dtype=torch.float64), tensor([0., 1., 0.], dtype=torch.float64), tensor([0., 0., 1.], dtype=torch.float64), tensor([1., 0., 0.], dtype=torch.float64)]
    

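    Note that we use PyTorch's default_collate function to avoid rewriting that logic. default_collate also turns a list of Python floats into a float64 tensor, which is why the labels above have dtype=torch.float64.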