
Override PyTorch Dataset efficiently


I want to subclass torch.utils.data.Dataset to load my custom image dataset, say for a classification task. Here is the example from the official PyTorch website:

import os
import pandas as pd
from torch.utils.data import Dataset
from torchvision.io import read_image

class CustomImageDataset(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.img_labels)

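    def __getitem__(self, idx):
        # Column 0 of the annotations file holds the file name, column 1 the label
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        image = read_image(img_path)  # decodes the image from disk on every call
        label = self.img_labels.iloc[idx, 1]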
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label

I have noticed that:

  1. In __getitem__ we read an image from disk into memory. This means that if we train the model for several epochs, we re-read the same image into memory several times. As far as I know, this is a costly operation.
  2. A transform is applied each time an image is read from disk, which seems nearly redundant to me.

I understand that with very large datasets the data cannot fit fully into memory, so we have no choice but to read it this way (since we must iterate over all the data in every epoch). But in the case where all my data fits into memory, isn't reading it all from disk in the __init__ function a better approach?

From my limited experience in computer vision, I have noticed that cropping images to a fixed size is a very common step in the transform. So why not crop the images once, store them on disk somewhere else, and read only the cropped images during training? That seems more efficient to me.
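For a deterministic step such as a fixed center crop, the crop-once idea can be sketched as a one-off preprocessing pass. This is a minimal sketch, assuming the images are already decoded into uint8 tensors of shape (C, H, W) that are at least `size` pixels on each side; the helper names `center_crop` and `precrop_and_save` and the single-file `.pt` output are hypothetical choices for illustration, not a PyTorch convention.

```python
import os
import tempfile

import torch


def center_crop(img: torch.Tensor, size: int) -> torch.Tensor:
    """Deterministically take the central size x size window of a (C, H, W) tensor."""
    _, h, w = img.shape
    top = (h - size) // 2
    left = (w - size) // 2
    return img[:, top:top + size, left:left + size]


def precrop_and_save(images, size: int, out_path: str) -> None:
    """Hypothetical helper: crop every image once and store the stack in one file."""
    cropped = torch.stack([center_crop(img, size) for img in images])
    torch.save(cropped, out_path)


# Usage: crop once offline, then load only the cropped tensor during training.
images = [torch.randint(0, 256, (3, 40, 50), dtype=torch.uint8) for _ in range(4)]
out_path = os.path.join(tempfile.mkdtemp(), "cropped_32.pt")
precrop_and_save(images, 32, out_path)
cropped = torch.load(out_path)
print(cropped.shape)  # torch.Size([4, 3, 32, 32])
```

Keeping all crops in one tensor file only makes sense when they fit in memory; for a large dataset, writing one image file per crop (for example with `torchvision.io.write_png`) would be the more usual choice. This only works for deterministic crops; a random crop has to be recomputed on each access.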

I understand that some transforms, such as those used for augmentation rather than normalization, are better applied in __getitem__ so that the data is randomly generated each time rather than fixed.

Can you clarify the subject for me? If this is common knowledge that I'm missing, please point me to codebases that use the proper approach.


Solution

    1. in the __getitem__ we are reading an image from disk to memory. It means if we train our model for several epochs, we are re-reading the same image into memory several times. To my knowledge it is a costly action

    An alternative would be to cache each sample when it is first read, so that by the end of the first epoch all of the samples are cached in memory for faster subsequent access. However, this requires enough RAM to hold the entire dataset. This is probably the limitation the example is trying to work around: it reads each sample every time it is needed because of memory limitations (i.e. it trades repeated reads for lower memory use).
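The cache-on-first-read idea can be sketched as a wrapper around any map-style dataset. This is a minimal sketch, assuming the wrapped dataset returns `(image, label)` pairs and is indexed with plain integers; `CachedDataset` and the `CountingDataset` stand-in (which fakes disk reads and counts them) are hypothetical names. The transform is applied after the cache lookup, so random augmentation still produces a fresh result on every access.

```python
import torch
from torch.utils.data import Dataset


class CachedDataset(Dataset):
    """Hypothetical wrapper: keep each sample in RAM after its first read."""

    def __init__(self, base: Dataset, transform=None):
        self.base = base
        self.transform = transform
        self.cache = {}  # index -> (image, label) exactly as returned by base

    def __len__(self):
        return len(self.base)

    def __getitem__(self, idx):
        if idx not in self.cache:
            self.cache[idx] = self.base[idx]  # the (slow) read happens only here
        image, label = self.cache[idx]
        if self.transform is not None:
            image = self.transform(image)  # runs on every access, after the cache
        return image, label


class CountingDataset(Dataset):
    """Hypothetical stand-in for a disk-backed dataset that counts its reads."""

    def __init__(self, n: int):
        self.n = n
        self.reads = 0

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        self.reads += 1
        return torch.full((3, 8, 8), float(idx)), idx


base = CountingDataset(5)
cached = CachedDataset(base)
for epoch in range(3):
    for i in range(len(cached)):
        cached[i]
print(base.reads)  # 5: each sample was read once, not once per epoch
```

One caveat: with `DataLoader(num_workers > 0)`, each worker process gets its own copy of the dataset object, so the cache is filled per worker and, unless `persistent_workers=True`, discarded when the workers are shut down at the end of each epoch.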

    2. a transform is applied each time an image is read from disk and that seems to me a nearly redundant action.

    Transforms usually have randomness built in because it helps prevent the net from overfitting. That's why we transform every time a sample is requested: we need a different random transformation each time. If the transformations were applied only once and reused thereafter, it would defeat the purpose of random augmentation, because the net could simply learn the single static transformation.

    [...] in the case that all my data can be fit into memory, isn't reading it all from the disk in the __init__ function a better approach?

    I think that would be worth trying if disk-to-RAM speed is a bottleneck in your pipeline; in that case, pre-loading all the data once should speed things up. Bear in mind that even if your data fits in RAM on its own, total memory use during training will be higher because of the model's size and the training gradients (memory requirements also differ between training and inference, and between CPU and GPU).
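Eager pre-loading in `__init__` could look roughly like the following. This is a minimal sketch, assuming all images share one shape so they can be stacked into a single tensor; `PreloadedDataset` and `load_fn` are hypothetical names, and `load_fn` stands in for whatever reads one file (in the tutorial's setting, `torchvision.io.read_image`). A fake loader is used below so the example runs without image files.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class PreloadedDataset(Dataset):
    """Hypothetical dataset that reads every sample once in __init__."""

    def __init__(self, paths, labels, load_fn, transform=None):
        # One pass over the disk; assumes every image has the same (C, H, W) shape.
        self.images = torch.stack([load_fn(p) for p in paths])
        self.labels = torch.as_tensor(labels)
        self.transform = transform

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        image = self.images[idx]  # served from RAM, no disk access
        if self.transform is not None:
            image = self.transform(image)  # random augmentation still runs per access
        return image, self.labels[idx]


def fake_load(path):
    """Hypothetical loader returning a blank uint8 image instead of reading `path`."""
    return torch.zeros(3, 16, 16, dtype=torch.uint8)


ds = PreloadedDataset(["a.png", "b.png", "c.png"], [0, 1, 0], fake_load)
loader = DataLoader(ds, batch_size=2, shuffle=False)
images, labels = next(iter(loader))
print(images.shape, labels.tolist())  # torch.Size([2, 3, 16, 16]) [0, 1]
```

Stacking into one tensor keeps the data in a single buffer; if the images differ in size, a plain Python list of tensors would be needed instead.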

    [...] why shouldn't we crop the images once and store it on the disk somewhere else and throughout training only read the cropped images?

    Random cropping is usually used to prevent the net from simply memorising the exact features of an image. The net is forced to generalise beyond the randomness to more robust and general properties. Cropping once and re-using the same image would mean the net could learn the single crop and not generalise as well to unseen data.
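torchvision provides random cropping as `transforms.RandomCrop`; the hand-written `random_crop` below is only a minimal sketch of the same idea, used to contrast cropping once with cropping on every access. It assumes a (C, H, W) tensor at least `size` pixels on each side, and the 20 "epochs" are a hypothetical illustration.

```python
import torch


def random_crop(img: torch.Tensor, size: int, generator=None) -> torch.Tensor:
    """Crop a random size x size window from a (C, H, W) tensor (needs H, W >= size)."""
    _, h, w = img.shape
    top = int(torch.randint(0, h - size + 1, (1,), generator=generator))
    left = int(torch.randint(0, w - size + 1, (1,), generator=generator))
    return img[:, top:top + size, left:left + size]


img = torch.arange(32 * 32, dtype=torch.float32).reshape(1, 32, 32)
g = torch.Generator().manual_seed(0)

# Crop once, then reuse the stored crop in every one of 20 epochs: one view.
fixed = random_crop(img, 24, g)
fixed_views = {tuple(fixed.flatten().tolist()) for _ in range(20)}

# Crop on every access, as in __getitem__: many different views of one image.
fresh_views = {tuple(random_crop(img, 24, g).flatten().tolist()) for _ in range(20)}

print(len(fixed_views), len(fresh_views))
```

With a 32x32 image and a 24x24 crop there are 81 possible windows, so cropping on every access shows the net many different views, whereas the stored crop is always the same one.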