Tags: python-3.x, pytorch, conv-neural-network, pytorch-dataloader, efficientnet

Can I use PyTorch's DataLoader to load raw image data saved in CSV files?


I have raw image data saved in separate CSV files (one image per file). I want to train a CNN on it using PyTorch. How should I load the data so that it is suitable as the CNN's input? (Also, the images have 1 channel, while the ImageNet-pretrained network expects RGB input by default.)


Solution

  • PyTorch's DataLoader, as the name suggests, is just a utility class that loads your data in parallel, builds batches, shuffles, and so on. What you need instead is a custom Dataset implementation.

    Setting aside the fact that storing images in CSV files is somewhat unusual, you simply need something like this:

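    import glob
    from pathlib import Path
    from typing import Any, Tuple

    from torch.utils.data import Dataset, DataLoader


    class CustomDataset(Dataset):

        def __init__(self, path: Path):
            # do some preliminary checks, e.g. your path exists, files are there...
            assert path.exists(), f"{path} does not exist"
            # retrieve your files in some way, e.g. glob; sorting keeps the
            # index -> file mapping stable between runs
            self.csv_files = sorted(glob.glob(str(path / "*.csv")))

        def __len__(self) -> int:
            # this lets you know len(dataset) once you instantiate it
            return len(self.csv_files)

        def __getitem__(self, index: int) -> Tuple[Any, Any]:
            # this method is called by the dataloader, each index refers to
            # a CSV file in the list you built in the constructor
            csv_file = self.csv_files[index]
            # now do whatever you need to do and return some tensors
            image, label = self.load_image(csv_file)
            return image, label

        def load_image(self, csv_file: str) -> Tuple[Any, Any]:
            # placeholder: parse the CSV into an image tensor and a label;
            # how you do this depends on how your files are laid out
            raise NotImplementedError
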

    And that's it, more or less. You can then create your dataset, pass it to a DataLoader, and iterate over the loader, something like:

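    dataset = CustomDataset(Path("path/to/csv/files"))
    # batch_size=32 is only an example value; add any other DataLoader
    # arguments you need
    train_loader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=8)

    for images, labels in train_loader:
        # one batch of images and labels, as returned by __getitem__
        ...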
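As a more concrete sketch, here is one way the loading step could look for a specific, assumed file layout, including the 1-channel issue from the question. It assumes each CSV file holds one grayscale image as a grid of pixel values in [0, 255] (one pixel row per CSV row), that all images have the same height and width so they can be batched, and that labels come from a separate mapping keyed by file name. The class name CsvImageDataset, the labels mapping and the to_rgb flag are hypothetical. Repeating the single channel three times is one common way to feed grayscale images to a network that expects RGB input.

```python
import csv
from pathlib import Path
from typing import Dict, Tuple

import torch
from torch.utils.data import DataLoader, Dataset


class CsvImageDataset(Dataset):
    """Hypothetical dataset: one grayscale image per CSV file.

    Assumptions (adapt to your data):
    - every CSV row is one row of pixels, with values in [0, 255];
    - labels come from a separate mapping keyed by file name without ".csv";
    - all images have the same height and width, so they can be batched.
    """

    def __init__(self, root: Path, labels: Dict[str, int], to_rgb: bool = False):
        if not root.is_dir():
            raise FileNotFoundError(f"{root} is not a directory")
        # sorted() gives a deterministic index -> file mapping
        self.files = sorted(root.glob("*.csv"))
        self.labels = labels
        self.to_rgb = to_rgb

    def __len__(self) -> int:
        return len(self.files)

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, int]:
        path = self.files[index]
        with path.open(newline="") as f:
            # skip empty lines, parse every cell as a float pixel value
            rows = [[float(v) for v in row] for row in csv.reader(f) if row]
        # (H, W) -> (1, H, W): a single channel, since Conv2d expects C x H x W;
        # dividing by 255 scales the assumed [0, 255] range to [0, 1]
        image = torch.tensor(rows, dtype=torch.float32).unsqueeze(0) / 255.0
        if self.to_rgb:
            # copy the single channel three times -> (3, H, W) for RGB models
            image = image.repeat(3, 1, 1)
        return image, self.labels[path.stem]
```

With the default collate function, a DataLoader over this dataset yields image batches of shape (batch_size, 3, H, W) when to_rgb=True and a tensor of integer labels. An alternative to repeating the channel is to replace the network's first convolution with one that takes a single input channel; which layer that is depends on the model implementation you use.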