pytorch, pytorch-lightning

How to convert a generator to a PyTorch DataLoader?


I have a generator that creates synthetic data. How can I convert it into a PyTorch DataLoader?


Solution

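  • You can wrap your generator in a torch.utils.data.IterableDataset:

    from torch.utils import data

    class IterDataset(data.IterableDataset):
        def __init__(self, generator):
            self.generator = generator

        def __iter__(self):
            # DataLoader calls iter() on the dataset. A generator can only be
            # consumed once, so this dataset supports a single pass (one epoch).
            return self.generator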
    

    You can then pass this dataset to a data.DataLoader, which collects the items the generator yields into batches.

    Here is a minimal example showing its use:

    >>> gen = (x for x in range(10))
    >>> dataset = IterDataset(gen)
    >>> for i in data.DataLoader(dataset, batch_size=2):
    ...     print(i)
    tensor([0, 1])
    tensor([2, 3])
    tensor([4, 5])
    tensor([6, 7])
    tensor([8, 9])
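The wrapper above hands DataLoader the same generator object on every pass, so a second epoch yields nothing. Also, with num_workers > 0 each worker process gets its own copy of the dataset, so a single shared generator either fails to pickle or has its samples duplicated across workers. Below is a minimal sketch of one common workaround, assuming your synthetic data can be recreated by calling a generator function: store the function instead of the generator object, build a fresh generator in __iter__, and use torch.utils.data.get_worker_info() to give each worker a disjoint share of the samples. The names synthetic_samples and ReiterableDataset are hypothetical.

```python
import itertools

from torch.utils import data


def synthetic_samples(n):
    # Hypothetical stand-in for your synthetic-data generator.
    for i in range(n):
        yield i


class ReiterableDataset(data.IterableDataset):
    """Stores a generator *function* so every epoch gets a fresh generator."""

    def __init__(self, generator_fn, *args):
        self.generator_fn = generator_fn
        self.args = args

    def __iter__(self):
        gen = self.generator_fn(*self.args)
        info = data.get_worker_info()
        if info is None:
            # Single-process loading (num_workers=0): yield every sample.
            return gen
        # Multi-process loading: worker k keeps samples k, k + num_workers, ...
        # so no sample is produced by more than one worker.
        return itertools.islice(gen, info.id, None, info.num_workers)


dataset = ReiterableDataset(synthetic_samples, 10)
loader = data.DataLoader(dataset, batch_size=2)

for epoch in range(2):
    # Each epoch calls __iter__ again, which creates a new generator.
    print(epoch, [batch.tolist() for batch in loader])
```

With this sharding scheme every worker still runs the full generator and discards the samples it does not own. That is cheap for lightweight synthetic data, but an expensive generator would be better split at the source. Batches from different workers also arrive interleaved rather than in generator order.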