pandas keras pytorch torchvision image-classification

Loading image data from pandas dataframe with path and class label to PyTorch DataLoader?


I have a dataset of images that I want to classify (binary classification) using PyTorch. The training data is a CSV file whose columns include the image path and the image label (0 or 1). The CSV format is shown below:

img_path,printer_id,print_id,has_under_extrusion
101/1678589738/1678589914.060332.jpg,101,1678589738,1
101/1678589738/1678589914.462857.jpg,101,1678589738,1
101/1678589738/1678589914.875075.jpg,101,1678589738,1
...
...

I loaded this data into a pandas DataFrame with pd.read_csv(). I want to use the DataFrame to populate a PyTorch torch.utils.data.Dataset and ultimately wrap it in a DataLoader, applying transforms to each image along the way. I'm looking for something similar to Keras' flow_from_dataframe(). Does PyTorch offer an equivalent function, or would this require a custom implementation?


Solution

  • No, PyTorch has no built-in equivalent of Keras' flow_from_dataframe(). However, you can easily define your own custom Dataset like this:

    import os
    
    import pandas as pd
    from torchvision import transforms
    from PIL import Image
    from torch.utils.data import Dataset, DataLoader
    
    class MyDataset(Dataset):
        def __init__(self, csv_file, root_dir="", transform=None):
            # One row per image; img_path values are relative to root_dir
            self.data = pd.read_csv(csv_file)
            self.root_dir = root_dir
            self.transform = transform
    
        def __len__(self):
            return len(self.data)
    
        def __getitem__(self, index):
            row = self.data.iloc[index]  # positional lookup, independent of index labels
            image_path = os.path.join(self.root_dir, row["img_path"])
            # convert("RGB") guarantees 3 channels, which Normalize below expects
            image = Image.open(image_path).convert("RGB")
            label = int(row["has_under_extrusion"])  # I guess this is your class
            if self.transform:
                image = self.transform(image)
            return image, label
    
    # Define transformations to apply to the images
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        # ImageNet statistics, the usual choice for pretrained torchvision models
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    
    # Create a PyTorch DataLoader object
    # 'path/to/images' is a placeholder for the folder the img_path values are relative to
    dataset = MyDataset('my_data.csv', root_dir='path/to/images', transform=transform)
    dataloader = DataLoader(dataset, batch_size=32, shuffle=True)