I have some pretty simple code where I download a dataset from the HuggingFace datasets library, try to resize the images, and create a dataloader from it. However, I get an error saying the images cannot be stacked because they still have their original sizes, as if the resize transform were never applied:
from torchvision.transforms import ToTensor, Compose, Resize
from torch.utils.data import DataLoader
from datasets import load_dataset
dataset_name = 'food101'
size = 255
resize = Compose([Resize(size),ToTensor()])
dataset = load_dataset(dataset_name, split='train')
dataset.set_transform(resize)
dataset.set_format('torch')
dataloader = DataLoader(dataset, batch_size=32)
for batch in dataloader:
    print(batch)
I get the following error:
RuntimeError: stack expects each tensor to be equal size, but got [512, 384, 3] at entry 0 and [512, 512, 3] at entry 1
I'm extremely confused here. Whether I use set_transform() or with_transform(), it does not seem like the transformation is ever actually applied. What am I doing wrong?
I also tried applying it with a function like this, which didn't make a difference:
def transform(examples):
    examples['image'] = [resize(img) for img in examples['image']]
    return examples
dataset.set_transform(transform)
First, according to the datasets docs, the dataset.set_format method resets the transformation set by set_transform. And since your resize pipeline already converts the images to PyTorch tensors via ToTensor, I believe there is no need for set_format at all. (If you do want it, call it before set_transform so it does not override the transform.)
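As a small sketch of the ordering (reusing the transform function defined below), the call that comes last determines the active format:
dataset.set_format('torch')       # optional; must come first, or it resets the transform
dataset.set_transform(transform)  # whatever is set last wins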
Second, since the images have different heights and widths, you should pass both dimensions to the transform, i.e. Resize((size, size)), so every image ends up with the same shape. With a single int, Resize(size) only rescales the shorter edge and preserves the aspect ratio, which is why the resulting tensors still differ in size.
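To illustrate with a quick check on a dummy PIL image (the 512x384 size here is just an example matching the error message):
from PIL import Image
from torchvision.transforms import Resize
img = Image.new('RGB', (512, 384))       # PIL sizes are (width, height)
print(Resize(255)(img).size)             # (340, 255): only the shorter edge becomes 255
print(Resize((255, 255))(img).size)      # (255, 255): both dimensions fixed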
Overall, this would work:
resize = Compose([Resize((size, size)),ToTensor()])
def transform(examples):
    examples['image'] = [resize(img) for img in examples['image']]
    return examples
# dataset.set_format('torch')
dataset.set_transform(transform)
dataloader = DataLoader(dataset, batch_size=32)
Note that the transform(examples) wrapper function is still required: set_transform expects a function that receives a batch of examples (a dict of lists) and returns the updated batch, not a torchvision transform object directly.
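Equivalently, you could use with_transform, which returns a new dataset instead of modifying it in place (same transform function, just a different call style):
dataset = dataset.with_transform(transform)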
for batch in dataloader:
    print(batch.keys())
    print(batch['image'].shape)
    print(batch['label'].shape)
    break
>>>
dict_keys(['image', 'label'])
torch.Size([32, 3, 255, 255])
torch.Size([32])