Dataloader on Iterable dataset yields copied batches for num_workers > 0


The title says it all. An iterable dataset used with a multi-worker DataLoader yields more batches than it should (it seems that each worker yields all of the batches separately). Here is an MWE:

import torch


class ToyDataset(torch.utils.data.IterableDataset):
    def __iter__(self):
        data = torch.arange(len(self))
        yield from data

    def __len__(self):
        return 386


dataset = ToyDataset()
loader = torch.utils.data.DataLoader(dataset, batch_size=256, num_workers=2)
print(len(loader), len(list(loader))) # 2 4
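Counting batches only hints at what is going on. One way to confirm that each worker replays the whole dataset is to flatten the batches and count how often each value appears. The sketch below reuses the MWE's dataset; `count_values` is a hypothetical helper name, not part of PyTorch:

```python
from collections import Counter

import torch
from torch.utils.data import DataLoader, IterableDataset


class ToyDataset(IterableDataset):
    # Same as the MWE: __iter__ has no per-worker logic.
    def __iter__(self):
        yield from torch.arange(len(self))

    def __len__(self):
        return 386


def count_values(num_workers):
    loader = DataLoader(ToyDataset(), batch_size=256, num_workers=num_workers)
    # Flatten every batch and count how many times each value was produced.
    return Counter(v for batch in loader for v in batch.tolist())


if __name__ == "__main__":
    print(set(count_values(0).values()))  # {1}: each value appears once
    print(set(count_values(2).values()))  # {2}: each value appears once per worker
```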

Am I missing something? Is this a bug in PyTorch (though that seems highly unlikely)? And, most importantly, is there any way around this?

I also posted this on the PyTorch discussion forums, but it didn't get much attention.


Solution

  • This behavior is expected and explained in the IterableDataset documentation:

    When a subclass is used with DataLoader, each item in the dataset will be yielded from the DataLoader iterator. When num_workers > 0, each worker process will have a different copy of the dataset object, so it is often desired to configure each copy independently to avoid having duplicate data returned from the workers. get_worker_info(), when called in a worker process, returns information about the worker. It can be used in either the dataset's __iter__() method or the DataLoader's worker_init_fn option to modify each copy's behavior.

    The linked documentation page also gives two examples of using an IterableDataset with multiple workers: one calls get_worker_info() in the dataset's __iter__ method, the other uses the DataLoader's worker_init_fn.

    As a simple example:

    import torch
    import math
    
    class ToyDataset(torch.utils.data.IterableDataset):
        def __init__(self, start, end):
            self.start = start
            self.end = end
        
        def __iter__(self):
            worker_info = torch.utils.data.get_worker_info()
            
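            if worker_info is None:
                # Single-process loading (num_workers=0): yield the whole range.
                iter_start = self.start
                iter_end = self.end
            else:
                # Inside a worker process: give each worker its own contiguous,
                # non-overlapping chunk of [start, end) so no item is duplicated.
                per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers)))
                worker_id = worker_info.id
                iter_start = self.start + worker_id * per_worker
                iter_end = min(iter_start + per_worker, self.end)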
            data = torch.arange(iter_start, iter_end)
            yield from data
    
    dataset = ToyDataset(0, 386)
    
    loader = torch.utils.data.DataLoader(dataset, batch_size=256, num_workers=2)
    
    print(len(list(loader)))  # 2
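The documentation's second approach moves the splitting into the DataLoader's worker_init_fn, so that __iter__ stays unaware of workers. Below is a minimal sketch modeled on that idea; `RangeDataset` and `worker_init_fn` are illustrative names, and it relies on `get_worker_info().dataset` being the calling worker's own copy of the dataset:

```python
import math

from torch.utils.data import DataLoader, IterableDataset, get_worker_info


class RangeDataset(IterableDataset):
    """Yields start, start + 1, ..., end - 1; no worker logic in __iter__."""

    def __init__(self, start, end):
        self.start = start
        self.end = end

    def __iter__(self):
        return iter(range(self.start, self.end))


def worker_init_fn(worker_id):
    # Runs inside each worker process; info.dataset is that worker's copy.
    info = get_worker_info()
    dataset = info.dataset
    overall_start, overall_end = dataset.start, dataset.end
    per_worker = math.ceil((overall_end - overall_start) / info.num_workers)
    # Narrow this copy to a contiguous, non-overlapping slice of the range.
    dataset.start = overall_start + worker_id * per_worker
    dataset.end = min(dataset.start + per_worker, overall_end)


if __name__ == "__main__":
    loader = DataLoader(RangeDataset(0, 386), batch_size=256,
                        num_workers=2, worker_init_fn=worker_init_fn)
    batches = list(loader)
    print(len(batches))               # 2
    print([len(b) for b in batches])  # [193, 193]
```

Because each worker batches only its own shard, both batches here hold 193 items rather than 256 and 130; the same is true of the __iter__-based example. Where worker processes are started with spawn (the default on Windows and macOS), the DataLoader code has to stay under the `if __name__ == "__main__":` guard.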