
Iterable PyTorch dataset with multiple workers


So I have a text file bigger than my RAM, and I would like to create a dataset in PyTorch that reads it line by line, so I don't have to load it all into memory at once. I found PyTorch's IterableDataset as a potential solution for my problem. It only works as expected when using 1 worker; with more than one worker it creates duplicate records. Let me show you an example:

Having a testfile.txt containing:

0 - Dummy line
1 - Dummy line
2 - Dummy line
3 - Dummy line
4 - Dummy line
5 - Dummy line
6 - Dummy line
7 - Dummy line
8 - Dummy line
9 - Dummy line

Defining an IterableDataset:

from torch.utils.data import IterableDataset


class CustomIterableDatasetv1(IterableDataset):

    def __init__(self, filename):

        #Store the filename in object's memory
        self.filename = filename

    def preprocess(self, text):

        ### Do something with text here
        text_pp = text.lower().strip()
        ###

        return text_pp

    def line_mapper(self, line):
        
        #Splits the line into text and label and applies preprocessing to the text
        text, label = line.split('-')
        text = self.preprocess(text)

        return text, label


    def __iter__(self):

        #Create an iterator
        file_itr = open(self.filename)

        #Map each element using the line_mapper
        mapped_itr = map(self.line_mapper, file_itr)
        
        return mapped_itr

We can now test it:

from torch.utils.data import DataLoader

base_dataset = CustomIterableDatasetv1("testfile.txt")
#Wrap it in a DataLoader
dataloader = DataLoader(base_dataset, batch_size = 1, num_workers = 1)
for X, y in dataloader:
    print(X,y)

It outputs:

('0',) (' Dummy line\n',)
('1',) (' Dummy line\n',)
('2',) (' Dummy line\n',)
('3',) (' Dummy line\n',)
('4',) (' Dummy line\n',)
('5',) (' Dummy line\n',)
('6',) (' Dummy line\n',)
('7',) (' Dummy line\n',)
('8',) (' Dummy line\n',)
('9',) (' Dummy line',)

That is correct. But if I change the number of workers to 2, the output becomes:

('0',) (' Dummy line\n',)
('0',) (' Dummy line\n',)
('1',) (' Dummy line\n',)
('1',) (' Dummy line\n',)
('2',) (' Dummy line\n',)
('2',) (' Dummy line\n',)
('3',) (' Dummy line\n',)
('3',) (' Dummy line\n',)
('4',) (' Dummy line\n',)
('4',) (' Dummy line\n',)
('5',) (' Dummy line\n',)
('5',) (' Dummy line\n',)
('6',) (' Dummy line\n',)
('6',) (' Dummy line\n',)
('7',) (' Dummy line\n',)
('7',) (' Dummy line\n',)
('8',) (' Dummy line\n',)
('8',) (' Dummy line\n',)
('9',) (' Dummy line',)
('9',) (' Dummy line',)

This is incorrect, as each worker in the data loader creates a duplicate of every sample.

Is there a way to solve this issue with PyTorch, so that a dataloader can be created that does not load the whole file into memory and still supports multiple workers?
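
To make the duplication visible, each sample can be tagged with the id of the worker that produced it. Below is a minimal diagnostic sketch (the TaggedDataset subclass is hypothetical, purely for illustration) showing that every worker runs __iter__ on its own copy of the dataset and therefore iterates over the entire file:

from torch.utils.data import DataLoader, get_worker_info

class TaggedDataset(CustomIterableDatasetv1):

    def __iter__(self):
        #Each worker process calls __iter__ on its own copy of the dataset
        worker_id = get_worker_info().id
        return ((worker_id, sample) for sample in super().__iter__())

tagged_loader = DataLoader(TaggedDataset("testfile.txt"), batch_size = 1, num_workers = 2)
for worker_id, (X, y) in tagged_loader:
    print(worker_id.item(), X, y)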


Solution

  • So I found an answer on the PyTorch discussion forum (https://discuss.pytorch.org/t/iterable-pytorch-dataset-with-multiple-workers/135475/3), where they pointed out that I should use the worker info to slice the iterator, so that each worker starts at its own worker id and steps by the total number of workers.

    The new dataset would look like this:

    import itertools

    import torch
    from torch.utils.data import IterableDataset


    class CustomIterableDatasetv1(IterableDataset):
    
        def __init__(self, filename):
    
            #Store the filename in object's memory
            self.filename = filename
    
        def preprocess(self, text):
    
            ### Do something with text here
            text_pp = text.lower().strip()
            ###
    
            return text_pp
    
        def line_mapper(self, line):
            
            #Splits the line into text and label and applies preprocessing to the text
            text, label = line.split('-')
            text = self.preprocess(text)
    
            return text, label
    
    
        def __iter__(self):
            #Query the current worker (get_worker_info() returns None when num_workers = 0)
            worker_info = torch.utils.data.get_worker_info()
            worker_id = worker_info.id if worker_info is not None else 0
            worker_total_num = worker_info.num_workers if worker_info is not None else 1
            #Create an iterator
            file_itr = open(self.filename)
    
            #Map each element using the line_mapper
            mapped_itr = map(self.line_mapper, file_itr)
            
            #Each worker keeps every worker_total_num-th line, starting at its worker_id
            mapped_itr = itertools.islice(mapped_itr, worker_id, None, worker_total_num)
    
            return mapped_itr
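
    To see how itertools.islice splits the work between workers, here is a minimal sketch (plain Python, no PyTorch involved) of how two workers would share ten lines:

    import itertools

    lines = range(10)
    #Worker 0 keeps every 2nd line, starting at index 0
    print(list(itertools.islice(lines, 0, None, 2)))  #[0, 2, 4, 6, 8]
    #Worker 1 keeps every 2nd line, starting at index 1
    print(list(itertools.islice(lines, 1, None, 2)))  #[1, 3, 5, 7, 9]

    Note that each worker still reads and preprocesses every line of the file; islice merely discards the lines that belong to other workers, so it is the yielded samples, not the I/O, that get de-duplicated.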
    

    Special thanks to @Ivan who also pointed out the slicing solution.

    With two workers, it now returns the same data as with a single worker.
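
    As a quick check, the earlier test can be re-run with num_workers = 2 (reusing testfile.txt from the question):

    from torch.utils.data import DataLoader

    base_dataset = CustomIterableDatasetv1("testfile.txt")
    dataloader = DataLoader(base_dataset, batch_size = 1, num_workers = 2)
    for X, y in dataloader:
        print(X, y)

    Each line now appears exactly once: worker 0 yields lines 0, 2, 4, …, worker 1 yields lines 1, 3, 5, …, and the DataLoader interleaves their outputs in order.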