Tags: python, import, pytorch, classification

Loading train/val/test datasets with images in separate folders using PyTorch


For my first PyTorch project, I have to perform image classification on a dataset of JPG images of clouds. I am struggling with loading the data, because the train/validation/test sets are not separated and the images are located in different folders according to their class. The folder structure looks like this:

-dataset_folder
    -Class_1
        img1
        img2
        ...
    -Class_2
        img1
        img2
        ...
    -Class_3
        img1
        img2
        ...
    -Class_4
        img1
        img2
        ...

I saw that the ImageFolder() class could handle this kind of folder structure, but I have no idea how to combine this with separating the dataset into 3 parts.

Can someone please show me a way to do this?


Solution

  • You can write a custom Dataset class to load your data and use it in your project:

    import os
    import glob
    import torch
    from torch.utils.data import Dataset
    from PIL import Image
    from torchvision.transforms import ToTensor
    
    class CustomImageDataset(Dataset):
        def __init__(self, root_dir, transform=None):
            self.root_dir = root_dir
            self.transform = transform
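            # Sort the class folders so the folder -> label mapping is the same on every run
            # (os.listdir returns entries in arbitrary order).
            self.class_folders = sorted(
                f for f in os.listdir(root_dir) if os.path.isdir(os.path.join(root_dir, f))
            )
            self.image_paths = []
            self.labels = []
    
            for label, class_folder in enumerate(self.class_folders):
                # Collect every .jpg in this class folder; the folder's position is its integer label.
                img_paths = sorted(glob.glob(os.path.join(root_dir, class_folder, '*.jpg')))
                self.image_paths.extend(img_paths)
                self.labels.extend([label] * len(img_paths))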
    
        def __len__(self):
            return len(self.image_paths)
    
        def __getitem__(self, idx):
            img_path = self.image_paths[idx]
            label = self.labels[idx]
            image = Image.open(img_path).convert('RGB')
    
            if self.transform:
                image = self.transform(image)
    
            return image, label
    

    See the PyTorch tutorial on custom datasets for more details: https://pytorch.org/tutorials/beginner/basics/data_tutorial.html#creating-a-custom-dataset-for-your-files

    After that, you can split your dataset into as many parts as you want. This answer shows how to split a custom dataset into different sets using SubsetRandomSampler: "How do I split a custom dataset into training and test datasets?"
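
As a rough illustration of the split step, here is a minimal sketch. The helper names `split_indices` and `make_loaders`, the 70/15/15 default proportions, and the fixed seed are assumptions made for this example; they are not part of the original answer or of PyTorch. The idea is to shuffle the sample indices once, cut them into three disjoint lists, and give each list to its own `SubsetRandomSampler`, so that all three loaders read from the same dataset object without overlapping.

```python
import random


def split_indices(n_samples, val_fraction=0.15, test_fraction=0.15, seed=0):
    """Shuffle range(n_samples) and cut it into train/val/test index lists."""
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)  # fixed seed -> reproducible split

    n_val = int(n_samples * val_fraction)
    n_test = int(n_samples * test_fraction)
    n_train = n_samples - n_val - n_test  # whatever is left goes to training

    train_idx = indices[:n_train]
    val_idx = indices[n_train:n_train + n_val]
    test_idx = indices[n_train + n_val:]
    return train_idx, val_idx, test_idx


def make_loaders(dataset, batch_size=32, **split_kwargs):
    """Build one DataLoader per split, all backed by the same dataset."""
    # Imported here so split_indices stays usable without PyTorch installed.
    from torch.utils.data import DataLoader, SubsetRandomSampler

    splits = split_indices(len(dataset), **split_kwargs)
    # Each sampler only ever draws from its own index list.
    return tuple(
        DataLoader(dataset, batch_size=batch_size, sampler=SubsetRandomSampler(idx))
        for idx in splits
    )
```

With the class above, this could be called as `train_loader, val_loader, test_loader = make_loaders(CustomImageDataset('dataset_folder', transform=ToTensor()))`. Two caveats: the shuffle is not stratified, so small classes may be spread unevenly across the three parts, and batching with `ToTensor()` alone only works if all images have the same size (otherwise a resize transform is needed first). `torch.utils.data.random_split` is another way to get the same kind of split as `Subset` objects.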