Tags: pytorch, iterator, pytorch-dataloader

Writing a custom PyTorch DataLoader iterator with pre-processing on the batch


A typical custom PyTorch Dataset looks like this:

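import torch

class TorchCustomDataset(torch.utils.data.Dataset):

    def __init__(self, filenames, speech_labels):
        # Store the file list and labels here (placeholder: nothing stored yet)
        pass

    def __len__(self):
        # Number of samples in the dataset (hard-coded placeholder)
        return 100

    def __getitem__(self, idx):
        # Load and pre-process sample `idx`; returns a (data, label) placeholder
        return 1, 0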

Here, in __getitem__ I can read any file and apply any pre-processing to that specific file.
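As a minimal sketch of per-sample pre-processing in __getitem__, the dataset below keeps a few in-memory 1-D signals in place of real files (the class name, the data, and the peak-normalization step are all hypothetical). Each call touches exactly one sample, which is why this hook cannot see the rest of the batch.

```python
import torch
from torch.utils.data import Dataset


class NormalizedSignalDataset(Dataset):
    """Hypothetical dataset: each item is a 1-D signal standing in for a loaded file."""

    def __init__(self, signals, labels):
        # `signals` is a list of 1-D tensors; in practice you would store filenames
        # here and read each file lazily inside __getitem__.
        self.signals = signals
        self.labels = labels

    def __len__(self):
        return len(self.signals)

    def __getitem__(self, idx):
        signal = self.signals[idx]
        # Per-sample pre-processing: scale this one signal to a peak amplitude of 1.
        peak = signal.abs().max().clamp(min=1e-8)
        return signal / peak, self.labels[idx]


signals = [torch.tensor([0.5, -2.0, 1.0]), torch.tensor([3.0, 0.0, -1.5])]
ds = NormalizedSignalDataset(signals, labels=[0, 1])
x, y = ds[0]  # x is the first signal divided by its own peak (2.0)
```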

What if I want to apply some tensor-level pre-processing to the whole batch of data? Technically, I could just iterate through the DataLoader, take each batch it yields, and apply the pre-processing to that batch.
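The iterate-and-transform approach described in the question might look like the sketch below. The data, the batch size, and the `standardize` helper (zero mean, unit variance per feature, computed over the batch) are hypothetical; the point is only that the transform runs once per batch in the training loop, outside the DataLoader.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical data: 10 samples with 4 features each, plus integer labels.
features = torch.arange(40, dtype=torch.float32).reshape(10, 4)
labels = torch.arange(10) % 2
loader = DataLoader(TensorDataset(features, labels), batch_size=4)


def standardize(batch):
    # Batch-level pre-processing: zero mean and unit variance per feature,
    # with statistics computed over the batch dimension (dim 0).
    mean = batch.mean(dim=0, keepdim=True)
    std = batch.std(dim=0, keepdim=True).clamp(min=1e-8)
    return (batch - mean) / std


processed = []
for x, y in loader:
    x = standardize(x)  # applied to the whole batch at once
    processed.append((x, y))
```

This works, but the pre-processing lives in every loop that consumes the loader rather than in the data pipeline itself.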

But how do I build this into a custom data loader? In short, what is the DataLoader equivalent of __getitem__ for applying an operation to the whole batch of data?
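One way to package the iterate-and-transform idea into a reusable "custom loader" object is a thin wrapper whose `__iter__` yields transformed batches. This is a hypothetical sketch (`BatchTransformLoader` and `scale_inputs` are illustrative names, not PyTorch APIs); it relies only on the fact that a DataLoader can be iterated again on each epoch.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


class BatchTransformLoader:
    """Hypothetical wrapper: iterates over a DataLoader and applies `transform`
    to every batch it yields."""

    def __init__(self, loader, transform):
        self.loader = loader
        self.transform = transform

    def __iter__(self):
        # Starts a fresh pass over the wrapped loader each time (e.g. each epoch).
        for batch in self.loader:
            yield self.transform(batch)

    def __len__(self):
        return len(self.loader)


def scale_inputs(batch):
    # Batch-level operation: divide the whole input batch by its global max.
    x, y = batch
    return x / x.abs().max().clamp(min=1e-8), y


data = TensorDataset(torch.arange(12, dtype=torch.float32).reshape(6, 2), torch.zeros(6))
loader = BatchTransformLoader(DataLoader(data, batch_size=3), scale_inputs)
batches = list(loader)
```

The wrapper keeps the training loop clean, but the transform still runs in the main process, after the DataLoader workers have produced each batch.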


Solution

  • You can pass a custom collate_fn to DataLoader. This function receives the list of individual samples returned by the Dataset's __getitem__ and combines them into a batch, so adding your pre-processing inside a custom collate_fn lets it operate on the whole batch.
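A minimal sketch of the collate_fn approach, assuming a recent PyTorch release where `default_collate` is importable from `torch.utils.data`. The toy dataset and the mean-centering step are hypothetical; the custom function first lets `default_collate` stack the samples as usual, then applies a batch-level operation to the stacked tensor.

```python
import torch
from torch.utils.data import DataLoader, Dataset, default_collate


class ToyDataset(Dataset):
    # Hypothetical stand-in data: item idx is (tensor of length 3 filled with idx, idx % 2).
    def __len__(self):
        return 8

    def __getitem__(self, idx):
        return torch.full((3,), float(idx)), idx % 2


def collate_and_center(samples):
    # `samples` is the list of (data, label) tuples returned by __getitem__.
    x, y = default_collate(samples)  # stacked tensors of shape (B, 3) and (B,)
    # Batch-level pre-processing: subtract the batch mean from every sample.
    x = x - x.mean(dim=0, keepdim=True)
    return x, y


loader = DataLoader(ToyDataset(), batch_size=4, collate_fn=collate_and_center)
x, y = next(iter(loader))  # first batch: items 0..3, centered around their mean 1.5
```

Because collate_fn runs inside the DataLoader (in the worker processes when `num_workers > 0`), the batch pre-processing is parallelized along with data loading. With multiple workers, define the function at module level so it can be pickled.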