Search code examples
Tags: python, pytorch, dataset

PyTorch ImageFolder vs. Custom Dataset from single folder


I have a multi-class image classification problem in which all my images are stored in one folder and the label of each image is encoded in its filename.

I'm new to PyTorch and was wondering why there is (as far as I know) only one built-in class, ImageFolder(), for building a Dataset from image files. Restructuring my images into the one-sub-folder-per-class layout that ImageFolder() expects (typically under separate, predefined train and test folders) seems quite cumbersome to me.

Is it reasonable to keep all images in one folder with the label in each filename? If so, why is there no built-in Dataset class like ImageFromOneFolder()?

I guess the way to go is to write a custom Dataset.

Thanks for your help


Solution

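  • A custom Dataset works well here: a map-style Dataset only needs `__len__` and `__getitem__`, so it can list the files in one folder and derive each label from the filename: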

    from PIL import Image
    from torch.utils.data import Dataset
    from pathlib import Path
    from torch.utils.data import DataLoader
    from torchvision import transforms
    
    
    class ImageFolderCustom(Dataset):
        """Image-classification dataset over a single flat folder of ``.jpg`` files.

        Unlike ``torchvision.datasets.ImageFolder`` (which takes the class from the
        sub-folder name), the label of each image is derived from its file name by
        :meth:`get_label`.

        Args:
            targ_dir: Directory containing the images (not searched recursively).
            transform: Optional callable applied to each loaded PIL image.

        Attributes:
            paths: Sorted list of image paths, so index -> file is reproducible.
            classes: Sorted list of distinct label strings.
            class_to_idx: Mapping from label string to its index in ``classes``.
        """

        def __init__(self, targ_dir, transform=None):
            # Sort so that index -> file is deterministic; raw glob order depends
            # on the file system.
            self.paths = sorted(Path(targ_dir).glob("*.jpg"))
            self.transform = transform
            self.classes = sorted({self.get_label(p) for p in self.paths})
            # Dict lookup in __getitem__ instead of list.index (O(num_classes)).
            self.class_to_idx = {name: idx for idx, name in enumerate(self.classes)}

        @staticmethod
        def get_label(path):
            """Return the class label encoded in *path*'s file name.

            Placeholder convention: the label is the last character of the file
            stem, e.g. ``cat_3.jpg`` -> ``"3"``. Adapt this to your naming scheme.
            """
            return path.stem[-1]

        def load_image(self, index):
            """Load the image at *index* as an in-memory RGB PIL image.

            The file is closed before returning. The image is converted to RGB (as
            torchvision's ImageFolder loader does) so grayscale/RGBA/palette files
            yield the same number of channels and can be batched together.
            """
            with Image.open(self.paths[index]) as img:
                # convert() reads the pixels and returns a new, independent image,
                # so the result stays valid after the file is closed.
                return img.convert("RGB")

        def __len__(self):
            return len(self.paths)

        def __getitem__(self, index):
            """Return ``(image, class_idx)`` for the sample at *index*."""
            img = self.load_image(index)
            class_idx = self.class_to_idx[self.get_label(self.paths[index])]
            if self.transform is not None:
                img = self.transform(img)
            return img, class_idx
    
    
    # Resize every image to 64x64 and convert it to a tensor, so all samples
    # share one shape and can be stacked into a batch.
    train_transforms = transforms.Compose(
        [transforms.Resize((64, 64)), transforms.ToTensor()]
    )

    # 'path/to/data' is a placeholder: point it at the folder holding the .jpg files.
    dataset = ImageFolderCustom('path/to/data', transform=train_transforms)

    # Shuffled mini-batches of 4 (image, class_idx) pairs.
    train_dataloader_custom = DataLoader(dataset=dataset, batch_size=4, shuffle=True)

    # Smoke test: fetch a single batch and show its labels.
    images, labels = next(iter(train_dataloader_custom))
    print(labels)