
Python, Dataset class: how to concatenate images with their respective labels in PyTorch


I am new to PyTorch, and for the last couple of days I have been struggling with the Dataset class, which lets you build your own custom dataset.

I am working with this dataset (https://www.kaggle.com/ianmoone0617/flower-goggle-tpu-classification/kernels). The problem is that the images and their labels are stored separately, and I can't figure out how to combine them.

This is the code I am using:

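import os
import pandas as pd
import torch
from skimage import io
from torch.utils.data import Dataset

class MyDataset(Dataset):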

    def __init__(self, csv_file, root_dir, transform=None):
        self.labels = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, index):
        if torch.is_tensor(index):
            index = index.tolist()

        image_name = os.path.join(self.root_dir, self.labels.iloc[index, 0])
        image = io.imread(image_name)

        if self.transform:
            image = self.transform(image)

        return (image, labels)  # NOTE: `labels` is never defined, so this line raises a NameError

The folder structure is as follows: [image: structure of the folders]

I really want to understand this, so thank you in advance, guys!


Solution

  • It seems like you're nearly there. There are many ways to deal with this. For example, you could read both CSV files during initialization and build a dictionary that maps each label string in flowers_idx.csv to the label index specified in flowers_label.csv.

    import os
    import pandas as pd
    import torch
    from torchvision.datasets.folder import default_loader
    from torch.utils.data import Dataset
    
    class MyDataset(Dataset):
        def __init__(self, data_csv, label_csv, root_dir, transform=None):
            self.data_entries = pd.read_csv(data_csv)
            self.root_dir = root_dir
            self.transform = transform
    
            label_map = pd.read_csv(label_csv)
            # assumes column 0 holds the integer label index and column 1 the label string
            self.label_str_to_idx = dict(zip(label_map.iloc[:, 1], label_map.iloc[:, 0]))
    
        def __len__(self):
            return len(self.data_entries)
    
        def __getitem__(self, index):
            if torch.is_tensor(index):
                index = index.item()
    
            # assumes column 0 of data_csv holds the image id and column 1 the label string
            label = self.label_str_to_idx[self.data_entries.iloc[index, 1]]
            image_path = os.path.join(self.root_dir, f'{self.data_entries.iloc[index, 0]}.jpeg')
    
            # torchvision datasets generally return PIL images rather than NumPy ndarrays
            image = default_loader(image_path)
    
            # alternative: load an ndarray with skimage.io (requires `from skimage import io`)
            # image = io.imread(image_path)
    
            if self.transform:
                image = self.transform(image)
    
            return (image, label)
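To check the label lookup on its own, here is a small sketch that builds the same kind of string-to-index dictionary from in-memory CSV text instead of the real files. The column names and rows are made up for illustration; like the class above, it assumes the label CSV lists the index first and the string second, and that the data CSV lists the image id first and the label string second.

```python
import io

import pandas as pd

# Hypothetical stand-ins for flowers_label.csv (index, string) and
# flowers_idx.csv (image id, label string); the real files may differ.
LABEL_CSV = """label_idx,label_str
0,daisy
1,rose
2,tulip
"""
IDX_CSV = """id,label
img_a,rose
img_b,daisy
img_c,rose
"""

label_map = pd.read_csv(io.StringIO(LABEL_CSV))
data_entries = pd.read_csv(io.StringIO(IDX_CSV))

# Column 0 holds the integer index, column 1 the label string.
label_str_to_idx = {
    str(label_str): int(label_idx)
    for label_idx, label_str in zip(label_map.iloc[:, 0], label_map.iloc[:, 1])
}

# The same per-row lookup __getitem__ performs: label string -> integer index.
labels = [label_str_to_idx[s] for s in data_entries.iloc[:, 1]]

print(label_str_to_idx)  # {'daisy': 0, 'rose': 1, 'tulip': 2}
print(labels)            # [1, 0, 1]
```

Building the dictionary once in `__init__` keeps `__getitem__` to a single dictionary lookup per sample instead of searching the label table every time.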
    

    Note that this returns PIL images rather than ndarrays, since that's what torchvision datasets generally return. This is also convenient because many torchvision transforms, at least in older torchvision releases, could only be applied to PIL images.

    For now a simple use case could be:

    import os
    import torchvision.transforms as tt
    
    dataset_dir = '/home/jodag/datasets/527293_966816_bundle_archive'
    # TODO add more transforms/data-augmentation etc...
    transform = tt.Compose((
        tt.ToTensor(),
    ))
    
    dataset = MyDataset(
        os.path.join(dataset_dir, 'flowers_idx.csv'),
        os.path.join(dataset_dir, 'flowers_label.csv'),
        os.path.join(dataset_dir, 'flower_tpu/flower_tpu/flowers_google/flowers_google'),
        transform)
    
    image, label = dataset[0]
    

    During training or validation, you would typically wrap the dataset in a torch.utils.data.DataLoader to sample shuffled batches from it.
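A minimal sketch of that batching step follows. So it runs without the Kaggle files, it uses a hypothetical stand-in dataset, `FakeFlowers`, that returns `(image_tensor, int_label)` pairs, the same shape of output `MyDataset` gives after `ToTensor`; in practice you would pass the `dataset` built above. One assumption to keep in mind: the default collate function stacks images, so they must all have the same size, and if the flower photos differ you would likely need something like `tt.Resize` in the transform.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class FakeFlowers(Dataset):
    """Hypothetical stand-in returning (image, label) pairs like MyDataset."""

    def __init__(self, n=10, num_classes=3, size=(3, 32, 32)):
        self.n = n
        self.num_classes = num_classes
        self.size = size

    def __len__(self):
        return self.n

    def __getitem__(self, index):
        image = torch.rand(self.size)     # same-sized CHW float tensor for every sample
        label = index % self.num_classes  # plain Python int, as from the label dictionary
        return image, label


dataset = FakeFlowers()

# shuffle=True reorders samples each epoch; the default collate function stacks
# images into (B, C, H, W) and turns the int labels into an int64 tensor of shape (B,).
loader = DataLoader(dataset, batch_size=4, shuffle=True)

for images, labels in loader:
    print(images.shape, labels.dtype)  # torch.Size([4, 3, 32, 32]) torch.int64
    break
```

With 10 samples and `batch_size=4`, the loader yields batches of 4, 4 and 2, since `drop_last` defaults to `False`.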