Tags: python, pytorch, dataloader

PyTorch dataloader shows odd behavior with string dataset


I'm working on an NLP problem and am using PyTorch. For some reason, my dataloader is returning malformed batches. My input data comprises sentences and integer labels. The sentences can be either a list of sentences or a list of lists of tokens. I will convert the tokens to integers later in a downstream component.

list_labels = [ 0, 1, 0]

# List of sentences.
list_sentences = [ 'the movie is terrible',
                   'The Film was great.',
                   'It was just awful.']

# Or a list of lists of tokens.
list_sentences = [['the', 'movie', 'is', 'terrible'],
                  ['The', 'Film', 'was', 'great.'],
                  ['It', 'was', 'just', 'awful.']]

I created the following custom dataset:

import torch
from torch.utils.data import DataLoader, Dataset

class MyDataset(Dataset):

    def __init__(self, sentences, labels):

        self.sentences = sentences
        self.labels = labels

    def __getitem__(self, i):
        result = {}
        result['sentences'] = self.sentences[i]
        result['label'] = self.labels[i]
        return result

    def __len__(self):
        return len(self.labels)

When I provide input in the form of a list of sentences, the dataloader correctly returns batches of complete sentences. Note that batch_size=2:

list_sentences = [ 'the movie is terrible', 'The Film was great.', 'It was just awful.']
list_labels = [ 0, 1, 0]


dataset = MyDataset(list_sentences, list_labels)
dataloader = DataLoader(dataset, batch_size=2)

batch = next(iter(dataloader))
print(batch)
# {'sentences': ['the movie is terrible', 'The Film was great.'], <-- Great! 2 sentences in batch!
#  'label': tensor([0, 1])}

The batch correctly contains two sentences and two labels because batch_size=2.

However, when I instead pass the sentences in as a pre-tokenized list of lists of tokens, I get weird results:

list_sentences = [['the', 'movie', 'is', 'terrible'], ['The', 'Film', 'was', 'great.'], ['It', 'was', 'just', 'awful.']]
list_labels = [ 0, 1, 0]


dataset = MyDataset(list_sentences, list_labels)
dataloader = DataLoader(dataset, batch_size=2)

batch = next(iter(dataloader))
print(batch)
# {'sentences': [('the', 'The'), ('movie', 'Film'), ('is', 'was'), ('terrible', 'great.')], <-- WHAT?
#  'label': tensor([0, 1])}

Note that the batch's 'sentences' entry is a single list of tuples, each pairing up the words at the same position across the two sentences. I was expecting sentences to be a list of two lists, like this:

{'sentences': [['the', 'movie', 'is', 'terrible'], ['The', 'Film', 'was', 'great.']]}

What is going on?


Solution

  • This behavior occurs because the default collate_fn does the following when it has to collate lists (which is the case for the 'sentences' values):

    # [...]
    elif isinstance(elem, container_abcs.Sequence):
        # check to make sure that the elements in batch have consistent size
        it = iter(batch)
        elem_size = len(next(it))
        if not all(len(elem) == elem_size for elem in it):
            raise RuntimeError('each element in list of batch should be of equal size')
        transposed = zip(*batch)
        return [default_collate(samples) for samples in transposed]
    

    The "problem" happens because, in the last two lines, it'll recursively call zip(*batch) while the batch is a container_abcs.Sequence (and list is), and zip behaves like this.

    You can see the transposition in isolation:

    batch = [['the', 'movie', 'is', 'terrible'], ['The', 'Film', 'was', 'great.']]
    list(zip(*batch))
    
    # [('the', 'The'), ('movie', 'Film'), ('is', 'was'), ('terrible', 'great.')]
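
    Note that zip(*...) is its own inverse, so if you do receive such a transposed batch, you can undo the damage the same way; a minimal sketch (transposed is just a local name for the mangled batch):

    transposed = [('the', 'The'), ('movie', 'Film'), ('is', 'was'), ('terrible', 'great.')]
    [list(s) for s in zip(*transposed)]

    # [['the', 'movie', 'is', 'terrible'], ['The', 'Film', 'was', 'great.']]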
    

    I don't see a workaround in your case other than implementing a custom collator and passing it to the DataLoader(..., collate_fn=mycollator). For instance, a simple (and admittedly ugly) one could be:

    def mycollator(batch):
        # Each sample is the dict produced by MyDataset.__getitem__.
        assert all('sentences' in x for x in batch)
        assert all('label' in x for x in batch)
        return {
            # Keep the (possibly tokenized) sentences exactly as they are...
            'sentences': [x['sentences'] for x in batch],
            # ...and only turn the integer labels into a tensor.
            'label': torch.tensor([x['label'] for x in batch])
        }