Tags: python, image-processing, pytorch, torchvision, huggingface-datasets

datasets.Dataset.set_transform() doesn't seem to apply transformations to images


I have some pretty simple code where I download a dataset from the HuggingFace datasets library, try to resize its images, and create a dataloader from it. However, I get an error saying the images cannot be stacked because they still have their original sizes in the dataset, despite the resize transform being applied:


from torchvision.transforms import ToTensor, Compose, Resize
from torch.utils.data import DataLoader
from datasets import load_dataset

dataset_name = 'food101'
size = 255

resize = Compose([Resize(size),ToTensor()])

dataset = load_dataset(dataset_name, split='train')
dataset.set_transform(resize)
dataset.set_format('torch')
dataloader = DataLoader(dataset, batch_size=32)

for batch in dataloader:
  print(batch)

I get the following error:

RuntimeError: stack expects each tensor to be equal size, but got [512, 384, 3] at entry 0 and [512, 512, 3] at entry 1

I'm extremely confused here. Whether I use set_transform() or with_transform(), it does not seem like the transformation is ever actually applied. What am I doing wrong here?

I also tried applying it with a function like this, which didn't make a difference:

def transform(examples):
  examples['image'] = [resize(img) for img in examples['image']]
  return examples

dataset.set_transform(transform)

Solution

First, according to the datasets docs, Dataset.set_format resets any transform that was set earlier, so calling set_format('torch') after set_transform overwrites the resize you just registered. Since the resize transform already converts the images to PyTorch tensors (via ToTensor), there is no need for set_format at all. (If you still want it, apply it before set_transform.)
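
As a quick sanity check, here is a minimal sketch (assuming the transform(examples) function from the question, and that Dataset.format exposes the current format type, as I recall from the datasets API): whichever formatting call comes last wins.

    dataset.set_transform(transform)    # format type becomes 'custom'
    dataset.set_format('torch')         # overrides it -> the custom transform is gone
    print(dataset.format['type'])       # expected to print 'torch', not 'custom'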

Second, if the images have different heights and widths, you should pass both dimensions, i.e. Resize((size, size)). With a single int, Resize(size) only rescales the shorter edge to size and keeps the aspect ratio, so images with different aspect ratios still end up with different shapes.
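
To see the difference concretely, here is a minimal sketch using a blank PIL image purely as a stand-in for a dataset sample:

    from PIL import Image
    from torchvision.transforms import Resize

    img = Image.new('RGB', (512, 384))   # (width, height) stand-in image

    print(Resize(255)(img).size)         # (340, 255): only the shorter edge becomes 255
    print(Resize((255, 255))(img).size)  # (255, 255): both edges fixed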

Overall, this would work:

    resize = Compose([Resize((size, size)), ToTensor()])
    
    def transform(examples):
      examples['image'] = [resize(img) for img in examples['image']]
      return examples
    
    # dataset.set_format('torch')
    dataset.set_transform(transform)
    dataloader = DataLoader(dataset, batch_size=32)
    

Note that the transform(examples) wrapper function is still required: set_transform expects a callable that receives a batch as a dict of lists (e.g. {'image': [...], 'label': [...]}) and returns a dict, so the Compose object cannot be passed to it directly.

    for batch in dataloader:
        print(batch.keys())
        print(batch['image'].shape)
        print(batch['label'].shape)
        break
    
    >>>
    dict_keys(['image', 'label'])
    torch.Size([32, 3, 255, 255])
    torch.Size([32])
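
Since the question also mentions with_transform(), the equivalent with that method would look roughly like this (a minimal sketch; with_transform returns a new dataset rather than setting the transform in place, but the transform function itself is the same):

    dataset = load_dataset(dataset_name, split='train')
    dataset = dataset.with_transform(transform)   # returns a new, formatted dataset
    dataloader = DataLoader(dataset, batch_size=32)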