
Best way to use Python iterator as dataset in PyTorch


The PyTorch DataLoader turns datasets into iterables. I already have a generator that yields the data samples I want to use for training and testing. I use a generator because the total number of samples is too large to fit in memory. I would like to load the samples in batches for training.

What is the best way to do this? Can I do it without writing a custom DataLoader? The PyTorch DataLoader does not accept a generator as input. Below is a minimal example of what I want to do; it fails with the error "object of type 'generator' has no len()".

import torch
from torch import nn
from torch.utils.data import DataLoader

def example_generator():
    for i in range(10):
        yield i


BATCH_SIZE = 3
train_dataloader = DataLoader(example_generator(),
                              batch_size=BATCH_SIZE,
                              shuffle=False)

print(f"Length of train_dataloader: {len(train_dataloader)} batches of {BATCH_SIZE}")
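For reference, here is a minimal sketch of where this error comes from, assuming a recent PyTorch release (internals may differ between versions). A plain generator is not an `IterableDataset`, so the DataLoader treats it as a map-style dataset and wraps it in a `SequentialSampler`, which needs `len(dataset)` to know how many indices to produce:

```python
from torch.utils.data import DataLoader, SequentialSampler


def example_generator():
    for i in range(10):
        yield i


# A generator is not an IterableDataset, so DataLoader falls back to
# map-style loading and builds a SequentialSampler over it.
loader = DataLoader(example_generator(), batch_size=3, shuffle=False)
print(type(loader.sampler).__name__)  # SequentialSampler

message = None
try:
    # len(loader) -> len(BatchSampler) -> len(SequentialSampler) -> len(generator)
    len(loader)
except TypeError as err:
    message = str(err)
print(message)  # object of type 'generator' has no len()
```

Iterating over `loader` fails the same way, because the sampler also calls `len()` on the dataset to build its index range.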

In short, I want to feed the data from an iterator into a PyTorch DataLoader so that I can use its batching functionality.

Edit: I want this to work for complex generators whose length is not known in advance.


Solution

  • PyTorch's DataLoader officially supports iterable-style datasets, as long as the dataset is an instance of a subclass of torch.utils.data.IterableDataset. From the documentation:

    An iterable-style dataset is an instance of a subclass of IterableDataset that implements the __iter__() protocol, and represents an iterable over data samples

    So your code would be written as:

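    from torch.utils.data import DataLoader, IterableDataset
    
    def example_generator():
        for i in range(10):
            yield i
    
    class MyIterableDataset(IterableDataset):
        def __init__(self, iterable):
            self.iterable = iterable
    
        def __iter__(self):
            # Return an iterator over the wrapped iterable. If the iterable is a
            # generator object, it is exhausted after the first full pass (one epoch).
            return iter(self.iterable)
    
    BATCH_SIZE = 3
    
    # shuffle must stay False: DataLoader raises a ValueError for shuffle=True with an
    # IterableDataset. len(train_dataloader) also raises a TypeError, because
    # MyIterableDataset does not define __len__.
    train_dataloader = DataLoader(MyIterableDataset(example_generator()),
                                  batch_size=BATCH_SIZE,
                                  shuffle=False)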
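Because the wrapper above stores a generator *object*, only the first epoch sees any data. A minimal sketch of one way around this, assuming you can pass a zero-argument function that builds a fresh generator (`GeneratorDataset` and `generator_factory` are hypothetical names, not part of PyTorch): call the factory inside `__iter__`, and use `torch.utils.data.get_worker_info()` so that, with `num_workers > 0`, each worker keeps only every `num_workers`-th sample instead of every worker emitting the full stream.

```python
import itertools

from torch.utils.data import DataLoader, IterableDataset, get_worker_info


def example_generator():
    # Stand-in for a large stream whose length is not known in advance
    for i in range(10):
        yield i


class GeneratorDataset(IterableDataset):
    """Hypothetical wrapper around a callable that returns a fresh generator."""

    def __init__(self, generator_factory):
        super().__init__()
        self.generator_factory = generator_factory

    def __iter__(self):
        # A new generator per call to __iter__, i.e. per epoch
        stream = self.generator_factory()
        worker = get_worker_info()
        if worker is None:
            # Single-process loading: yield every sample
            return stream
        # Multi-process loading: worker k keeps samples k, k + n, k + 2n, ...
        # so the workers do not duplicate each other's data
        return itertools.islice(stream, worker.id, None, worker.num_workers)


loader = DataLoader(GeneratorDataset(example_generator), batch_size=3)

epochs = []
for epoch in range(2):
    # Each epoch starts a new pass, so both epochs see all ten samples
    epochs.append([batch.tolist() for batch in loader])
print(epochs)
```

No length is needed: the DataLoader keeps drawing batches until the generator is exhausted, and the last batch is simply smaller (here `[9]`) unless `drop_last=True`. The `islice` sharding is the simplest option, but every worker still runs the whole generator and discards the other workers' samples; with multiple workers, each batch also comes from a single worker, so the order differs from single-process loading.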