
Pytorch Dataloader for Image GT dataset


I am new to PyTorch. I am trying to create a DataLoader for a dataset of images where each image has a corresponding ground truth with the same file name:

root:
--->RGB:
------>img1.png
------>img2.png
------>...
------>imgN.png
--->GT:
------>img1.png
------>img2.png
------>...
------>imgN.png

When I pass the path of the root folder (which contains the RGB and GT folders) to torchvision.datasets.ImageFolder, it reads all of the images as inputs and treats RGB and GT as two classes, so there seems to be no way to pair the RGB-GT images. I would like to pair the RGB-GT images, shuffle them, and divide them into batches of a defined size. How can this be done? Any advice will be appreciated. Thanks.


Solution

  • I think a good starting point is to use the VisionDataset class as a base. We will model our class on DatasetFolder (see its source code) and create something similar. Notice that DatasetFolder depends on two functions from the datasets.folder module: default_loader and make_dataset.

    We are not going to modify default_loader: it already does what we need (it loads an image from a path), so we will just import it.

    But we need a new make_dataset function that prepares the right pairs of images from the root folder. The original make_dataset pairs each image path with the class index of the folder it sits in, giving a list of (path, class_to_idx[target]) pairs, whereas we need (rgb_path, gt_path) pairs. Here is the code for the new make_dataset:

    import os


    def make_dataset(root: str) -> list:
        """Reads a directory with data.
        Returns a dataset as a list of tuples of paired image paths: (rgb_path, gt_path)
        """
        dataset = []

        # Our dir names
        rgb_dir = 'RGB'
        gt_dir = 'GT'

        # Get all the file names from the RGB folder (a set gives fast lookups)
        rgb_fnames = set(os.listdir(os.path.join(root, rgb_dir)))

        # Compare file names from the GT folder to file names from RGB
        for gt_fname in sorted(os.listdir(os.path.join(root, gt_dir))):
            # GT files without an RGB counterpart are skipped
            if gt_fname in rgb_fnames:
                # we have a match: build the full paths to the paired images
                rgb_path = os.path.join(root, rgb_dir, gt_fname)
                gt_path = os.path.join(root, gt_dir, gt_fname)
                dataset.append((rgb_path, gt_path))

        return dataset
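
The matching above silently skips any file that exists in only one of the two folders. Before training it may be worth listing those files. The following is a minimal sketch of such a check, not part of the original answer: the helper name find_unpaired and the file names in the throwaway directory are made up for illustration.

```python
import os
import tempfile


def find_unpaired(root, rgb_dir='RGB', gt_dir='GT'):
    """Return (rgb_only, gt_only): sorted file names missing a counterpart."""
    rgb = set(os.listdir(os.path.join(root, rgb_dir)))
    gt = set(os.listdir(os.path.join(root, gt_dir)))
    return sorted(rgb - gt), sorted(gt - rgb)


# Build a throwaway directory tree to try it on (file names are made up)
with tempfile.TemporaryDirectory() as root:
    layout = {'RGB': ['img1.png', 'img2.png', 'img3.png'],
              'GT': ['img1.png', 'img2.png', 'img4.png']}
    for folder, names in layout.items():
        os.makedirs(os.path.join(root, folder))
        for name in names:
            # empty placeholder files are enough to test name matching
            open(os.path.join(root, folder, name), 'w').close()

    rgb_only, gt_only = find_unpaired(root)

print(rgb_only)  # ['img3.png']
print(gt_only)   # ['img4.png']
```

If both lists are empty, every RGB image has a ground truth and nothing is dropped by the pairing.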
    

    What do we have now? Let's compare our function with the original one:

    from torchvision.datasets.folder import make_dataset as make_dataset_original
    
    
    root = './data'  # path to the folder that contains RGB and GT

    dataset_original = make_dataset_original(root, {'RGB': 0, 'GT': 1}, extensions='png')
    dataset = make_dataset(root)
    
    print('Original make_dataset:')
    print(*dataset_original, sep='\n')
    
    print('Our make_dataset:')
    print(*dataset, sep='\n')
    
    Original make_dataset:
    ('./data/GT/img1.png', 1)
    ('./data/GT/img2.png', 1)
    ...
    ('./data/RGB/img1.png', 0)
    ('./data/RGB/img2.png', 0)
    ...
    Our make_dataset:
    ('./data/RGB/img1.png', './data/GT/img1.png')
    ('./data/RGB/img2.png', './data/GT/img2.png')
    ...
    
    

    It works as intended. Now it's time to create our Dataset class. The most important part here is the __getitem__ method, because it loads the images, applies the transforms, and returns the samples that the dataloader will batch. We need to read a pair of images (RGB and GT) and return them as a tuple (a tuple of two image tensors once ToTensor is applied):

    from torchvision.datasets.folder import default_loader
    from torchvision.datasets.vision import VisionDataset
    
    
    class CustomVisionDataset(VisionDataset):
    
        def __init__(self,
                     root,
                     loader=default_loader,
                     rgb_transform=None,
                     gt_transform=None):
            super().__init__(root,
                             transform=rgb_transform,
                             target_transform=gt_transform)
    
            # Prepare dataset
            samples = make_dataset(self.root)
    
            self.loader = loader
            self.samples = samples
            # list of RGB images
            self.rgb_samples = [s[0] for s in samples]
            # list of GT images
            self.gt_samples = [s[1] for s in samples]
    
        def __getitem__(self, index):
            """Returns a data sample from our dataset.
            """
            # getting our paths to images
            rgb_path, gt_path = self.samples[index]
    
            # import each image using loader (by default it's PIL)
            rgb_sample = self.loader(rgb_path)
            gt_sample = self.loader(gt_path)
    
            # apply the transforms, if any were given
            # the RGB and GT images may need different transforms
            if self.transform is not None:
                rgb_sample = self.transform(rgb_sample)
            if self.target_transform is not None:
                gt_sample = self.target_transform(gt_sample)      
    
            # now we return the right imported pair of images (tensors)
            return rgb_sample, gt_sample
    
        def __len__(self):
            return len(self.samples)
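
One caveat with the class above: rgb_transform and gt_transform are applied independently, so a random transform (a random flip, for example) could change the RGB image and leave its ground truth untouched. A minimal sketch of one way around this, assuming both images are already tensors; paired_hflip is a hypothetical helper, not a torchvision function:

```python
import random

import torch


def paired_hflip(rgb, gt, p=0.5):
    """Flip both tensors left-right together with probability p."""
    # a single random draw decides the flip for both images
    if random.random() < p:
        rgb = torch.flip(rgb, dims=[-1])
        gt = torch.flip(gt, dims=[-1])
    return rgb, gt


# Tiny made-up tensors standing in for an image and its ground truth
rgb = torch.arange(12.0).reshape(1, 3, 4)
gt = rgb.clone()

rgb_flipped, gt_flipped = paired_hflip(rgb, gt, p=1.0)  # always flips
rgb_same, gt_same = paired_hflip(rgb, gt, p=0.0)        # never flips
```

Such a helper could be called in __getitem__ after the two per-image transforms, so the pair stays aligned.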
    

    Let's test it:

    from torch.utils.data import DataLoader
    
    from torchvision.transforms import ToTensor
    import matplotlib.pyplot as plt
    
    
    bs=4  # batch size
    transforms = ToTensor()  # we need this to convert PIL images to Tensor
    shuffle = True
    
    dataset = CustomVisionDataset('./data', rgb_transform=transforms, gt_transform=transforms)
    dataloader = DataLoader(dataset, batch_size=bs, shuffle=shuffle)
    
    for i, (rgb, gt) in enumerate(dataloader):
        print(f'batch {i+1}:')
        # some plots
        # use the actual batch size: the last batch may hold fewer than bs images
        for j in range(rgb.size(0)):
            plt.figure(figsize=(10, 5))
            plt.subplot(221)
            # matplotlib expects (H, W, C), tensors are (C, H, W)
            plt.imshow(rgb[j].squeeze().permute(1, 2, 0))
            plt.title(f'RGB img{j+1}')
            plt.subplot(222)
            plt.imshow(gt[j].squeeze().permute(1, 2, 0))
            plt.title(f'GT img{j+1}')
            plt.show()
    

    Out:

    batch 1:
    

    [figures: RGB and GT image pairs from the batch]

    ...
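
To check the batching without plotting, the batch shapes can be inspected directly. The sketch below uses a hypothetical stand-in dataset (RandomPairDataset, which returns random tensors instead of reading image files) so that it runs without any data on disk. It also shows that the last batch can be smaller than the batch size.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class RandomPairDataset(Dataset):
    """Hypothetical stand-in that returns random (rgb, gt) tensor pairs."""

    def __init__(self, n_samples=10, size=8):
        self.n_samples = n_samples
        self.size = size

    def __getitem__(self, index):
        rgb = torch.rand(3, self.size, self.size)
        gt = torch.rand(3, self.size, self.size)
        return rgb, gt

    def __len__(self):
        return self.n_samples


loader = DataLoader(RandomPairDataset(), batch_size=4, shuffle=True)

batch_sizes = []
for rgb, gt in loader:
    # the default collate function stacks each tuple element separately,
    # so rgb and gt are both (batch, 3, 8, 8) tensors
    assert rgb.shape == gt.shape
    batch_sizes.append(rgb.shape[0])

print(batch_sizes)  # 10 samples with batch_size=4 -> [4, 4, 2]
```

Passing drop_last=True to DataLoader would discard the smaller final batch if a fixed batch size is required.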

    Here you can find a notebook with code and simple dummy dataset.