I am trying to use torch.utils.data.Dataset with a custom dataset. Each row of my dataset contains a list of 10 images, like the (abbreviated) example below:
| word | images | gold_image |
|:-----|:-------|:-----------|
|'andromeda'|['image.1.jpg','image.2.jpg','image.3.jpg']|[0,0,1]|
I expect the dataloader to return a batch like this, with `batch_size=4`:

```
('word_1', 'word_2', 'word_3', 'word_4'), ([image_1,image_2,image_3],[image_4,image_5,image_6],[image_7,image_8,image_9],[image_10,image_11,image_12]), ([0,0,1],[1,0,0],[0,1,0],[0,1,0])
```

But instead I am getting this:

```
('word_1', 'word_2', 'word_3', 'word_4'), [(image_1,image_2,image_3,image_4),(image_5,image_6,image_7,image_8),(image_9,image_10,image_11,image_12)], [(0,1,0,0),(1,0,0,0),(0,1,0,1)]
```
Here is my code:
```python
import os

from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms


class ImageTextDataset(Dataset):
    def __init__(self, data_dir, train_df, tokenizer, feature_extractor, data_type, device, text_augmentation=False):
        self.data_dir = data_dir
        if data_type == "train":
            # this is for the original train set of the task
            # reshape all images to size [1440,1810]
            self.tokenizer = tokenizer
            self.feature_extractor = feature_extractor
            self.transforms = transforms.Compose([
                transforms.Resize([512, 512]),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ])
            self.all_image_names = list(train_df['images'])
            self.keywords = list(train_df['word'])
            self.context = list(train_df['description'])
            self.gold_images = list(train_df['gold_image'])

    def __len__(self):
        return len(self.context)

    def __getitem__(self, idx):
        context = self.context[idx]
        keyword = self.keywords[idx]
        # loading images
        label = []
        images = self.all_image_names[idx]
        image = []
        for i, img_name in enumerate(images):
            path = os.path.join(self.data_dir, "trial_images_v1", img_name)
            img = Image.open(path)
            if img.mode != "RGB":
                img = img.convert('RGB')
            img = self.transforms(img)
            image.append(img)
            # compare the file name (not the transformed tensor) against the gold image
            label.append(1.0) if img_name == self.gold_images[idx] else label.append(0.0)
        return (context, image, label)
```
I can't figure out what the issue is. Can anyone help?
TIA.
The `DataLoader` collates the output of your dataset into batches using the default collate function (implemented in `torch/utils/data/_utils/collate.py`). What you're observing is the expected behavior when a dataset returns sequence-type objects (like lists): the default collate function collates them element-wise, so the i-th elements of all samples in the batch end up grouped together.
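For example, here is what `default_collate` does to a hypothetical batch of two label lists (a minimal sketch; the values are made up just to illustrate the behavior):

```python
from torch.utils.data._utils.collate import default_collate

# two hypothetical samples, each a list of three per-image labels,
# i.e. what __getitem__ would return for the `label` part
batch = [[0.0, 0.0, 1.0],   # sample 1
         [1.0, 0.0, 0.0]]   # sample 2

print(default_collate(batch))
# [tensor([0., 1.], dtype=torch.float64),   # label of image 1 across the batch
#  tensor([0., 0.], dtype=torch.float64),   # label of image 2 across the batch
#  tensor([1., 0.], dtype=torch.float64)]   # label of image 3 across the batch
```

The per-sample lists are turned into per-position batches, which is exactly the "transposed" structure you're seeing for both `images` and `label`.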
If you want the dataloader to collate your data differently, then you can provide a custom collate function via the `collate_fn` argument of `DataLoader`.
You can read more about collation and custom collate functions in the documentation.
The following is a simple example of a custom collate function that I believe accomplishes what you want; you may need to tweak it for your exact use case.
```python
import torch
from torch.utils.data import DataLoader
from torch.utils.data._utils.collate import default_collate


class FakeDataset:
    """ Simple fake dataset for demonstration """
    def __getitem__(self, index):
        context = f'context_{index}'
        images = []
        for i in range(3):
            images.append(torch.full((2, 5, 5), index, dtype=torch.float))
        label = [0.0, 0.0, 0.0]
        label[index % 3] = 1.0
        return context, images, label

    def __len__(self):
        return 100


def my_collate_fn(batch):
    """ batch (list): each entry is assumed to be a tuple returned from FakeDataset.__getitem__ """
    contexts = default_collate([b[0] for b in batch])  # default_collate not actually necessary here
    images = [default_collate(b[1]) for b in batch]
    labels = [default_collate(b[2]) for b in batch]
    return contexts, images, labels


# define dataloader to use custom collate function
loader = DataLoader(FakeDataset(), batch_size=4, collate_fn=my_collate_fn)

# get one batch from the dataloader for demonstration
contexts, images, labels = next(iter(loader))
print('contexts =', contexts)
print('images (sizes) =', [t.shape for t in images])
print('labels =', labels)
```
which prints

```
contexts = ['context_0', 'context_1', 'context_2', 'context_3']
images (sizes) = [torch.Size([3, 2, 5, 5]), torch.Size([3, 2, 5, 5]), torch.Size([3, 2, 5, 5]), torch.Size([3, 2, 5, 5])]
labels = [tensor([1., 0., 0.], dtype=torch.float64), tensor([0., 1., 0.], dtype=torch.float64), tensor([0., 0., 1.], dtype=torch.float64), tensor([1., 0., 0.], dtype=torch.float64)]
```
Note that we make use of PyTorch's `default_collate` function to avoid having to rewrite that logic.
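If every sample always has the same number of equally-sized images (as in your 10-images-per-row case), you could also stack everything into single tensors instead of keeping Python lists. Here is a minimal sketch of that variant, reusing `FakeDataset` from above (the function name `my_collate_fn_stacked` is just for illustration):

```python
def my_collate_fn_stacked(batch):
    """ Variant: assumes every sample has the same number of equally-sized images. """
    contexts = [b[0] for b in batch]
    # stack each sample's image list, then stack across the batch -> [B, N, C, H, W]
    images = torch.stack([torch.stack(b[1]) for b in batch])
    # per-image labels as a single tensor -> [B, N]
    labels = torch.tensor([b[2] for b in batch])
    return contexts, images, labels


loader = DataLoader(FakeDataset(), batch_size=4, collate_fn=my_collate_fn_stacked)
contexts, images, labels = next(iter(loader))
print(images.shape)   # torch.Size([4, 3, 2, 5, 5])
print(labels.shape)   # torch.Size([4, 3])
```

This is often more convenient if you plan to feed all of a sample's images through the model in a single call.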