I have a generator that creates synthetic data. How can I convert this into a PyTorch dataloader?
You can wrap your generator with a data.IterableDataset:
from torch.utils import data

class IterDataset(data.IterableDataset):
    def __init__(self, generator):
        self.generator = generator

    def __iter__(self):
        # Hand the wrapped generator to the DataLoader as the sample stream.
        return self.generator
You can then wrap this dataset with a data.DataLoader.
Here is a minimal example showing its use:
>>> gen = (x for x in range(10))
>>> dataset = IterDataset(gen)
>>> for i in data.DataLoader(dataset, batch_size=2):
...     print(i)
tensor([0, 1])
tensor([2, 3])
tensor([4, 5])
tensor([6, 7])
tensor([8, 9])
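One caveat with the class above: it stores a generator object, which can only be consumed once, so a second pass over the DataLoader (for example a second epoch) yields nothing. Below is a minimal sketch of one way around this, assuming your synthetic data can be written as a generator function. The names synthetic_samples and GeneratorFnDataset are hypothetical and used only for illustration:

```python
from torch.utils import data


def synthetic_samples(n=10):
    # Hypothetical generator function standing in for your synthetic-data source.
    for x in range(n):
        yield x


class GeneratorFnDataset(data.IterableDataset):
    """Wraps a generator *function* so every pass gets a fresh generator."""

    def __init__(self, generator_fn, *args, **kwargs):
        super().__init__()
        self.generator_fn = generator_fn
        self.args = args
        self.kwargs = kwargs

    def __iter__(self):
        # Calling the function creates a new generator on each iteration.
        return self.generator_fn(*self.args, **self.kwargs)


dataset = GeneratorFnDataset(synthetic_samples, 10)
loader = data.DataLoader(dataset, batch_size=2)

# Iterate twice to show that the second epoch is not empty.
epochs = [[batch.tolist() for batch in loader] for _ in range(2)]
```

This keeps the default num_workers=0. With num_workers > 0, each worker process iterates its own copy of the dataset, so samples would be duplicated unless the generator is split per worker, for instance by using data.get_worker_info().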