Tags: pytorch, pytorch-dataloader, huggingface-datasets

PyTorch: Can I group batches by length?


I am working on an ASR project, where I use a model from HuggingFace (wav2vec2). My goal for now is to move the training process to PyTorch, so I am trying to recreate everything that HuggingFace’s Trainer() class offers.

One of these utilities is the ability to group batches by length and combine this with dynamic padding (via a data collator). To be honest, however, I am not sure how to even begin implementing this in PyTorch.
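
For context, this is roughly how those two features are switched on with the Trainer (a minimal sketch; the exact arguments depend on the setup, and audio datasets may additionally need a length column):

    from transformers import TrainingArguments

    # group_by_length buckets samples of similar length together;
    # the data collator passed to Trainer() then pads each batch dynamically.
    training_args = TrainingArguments(
        output_dir="out",
        per_device_train_batch_size=16,
        group_by_length=True,
    )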

The inputs in my case are 1-D arrays that represent the raw waveform of a .wav file. So before training I need to ensure that arrays of similar size are batched together. Do I need to create a custom DataLoader class and alter it so that each batch it yields contains samples with lengths as close to each other as possible?

An idea I had was to sort the data from shortest to longest (or the opposite) and extract batch_size samples at a time. This way, the first batch would consist of the samples with the largest lengths, the second batch of the next largest, and so on.

Nevertheless, I am not sure how to approach this implementation. Any advice will be greatly appreciated.

Thanks in advance.


Solution

  • One possible way of going about this is by using a batch sampler and implementing a collate_fn for your dataloader that will perform the dynamic padding on your batch elements.

    Take this basic dataset:

    import numpy as np
    import torch
    from torch.utils.data import Dataset, DataLoader

    class DS(Dataset):
        """Bare-bones dataset that just returns the raw 1-D waveform arrays."""
        def __init__(self, files):
            super().__init__()
            self.len = len(files)
            self.files = files

        def __getitem__(self, index):
            return self.files[index]

        def __len__(self):
            return self.len
    

    Initialized with some random data:

    >>> file_len = np.random.randint(0, 100, (16*6))   # 96 random lengths
    >>> files = [np.random.rand(s) for s in file_len]  # dummy 1-D "waveforms"
    >>> ds = DS(files)
    

    Start by defining your batch sampler: this is essentially an iterable returning batches of indices, which the data loader uses to retrieve the corresponding elements from the dataset. As you explained, we can simply sort the lengths and construct the batches from this sorted order (note that np.split requires the number of elements to divide evenly, which it does here since we created 16*6 of them):

    >>> batch_size = 6
    >>> # 16 batches of 6 indices each, sorted by decreasing length
    >>> batches = np.split(file_len.argsort()[::-1], len(file_len) // batch_size)
    

    Each batch should now contain elements that are close to each other in length.
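
    As a quick sanity check (purely illustrative, reusing the file_len array from above), you can print the spread of raw lengths inside the first few batches; the minimum and maximum should be close together:

    >>> for b in batches[:3]:
    ...   print(file_len[b].min(), file_len[b].max())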

    We can implement a collate_fn to assemble the batch elements and integrate dynamic padding. This essentially inserts an additional user-defined layer right between the dataset and the data loader. The goal is to find the longest element in the batch and pad all other elements with the appropriate number of zeros:

    def collate_fn(batch):
        # Pad every element with trailing zeros up to the length of the
        # longest element in the batch, then stack into a single tensor.
        longest = max(len(x) for x in batch)
        s = np.stack([np.pad(x, (0, longest - len(x))) for x in batch])
        return torch.from_numpy(s)
    
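    Just as an illustrative check with two toy arrays, calling collate_fn directly shows the shorter element being zero-padded to the length of the longest one:

    >>> collate_fn([np.ones(3), np.ones(5)]).shape
    torch.Size([2, 5])
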

    Then you can initialize a data loader:

    >>> dl = DataLoader(dataset=ds, batch_sampler=batches, collate_fn=collate_fn)
    

    Now try iterating; as you can see, we get batches of decreasing lengths:

    >>> for x in dl:
    ...   print(x.shape)
    torch.Size([6, 99])
    torch.Size([6, 93])
    torch.Size([6, 83])
    torch.Size([6, 76])
    torch.Size([6, 71])
    torch.Size([6, 66])
    torch.Size([6, 57])
    ...
    

    This method has some flaws, though. For instance, the composition of the batches will always be the same: because the batches are built from a fixed sort of the dataset by length, there is no variability in how they are formed, so you will always get the same batches in the same order of appearance. You can reduce this effect by shuffling the order in which the batches appear (e.g. by reshuffling the batches list before each epoch, or by drawing the batch order from a RandomSampler over the batch indices). However, the content of each batch will remain the same throughout training, which might lead to some problems.
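
    If that becomes a problem, one way to add variability is a small custom batch sampler. The class below is only a sketch of the idea (it is not part of the original answer): it shuffles the indices every epoch, sorts by length only within larger chunks so that batch composition changes between epochs, and finally shuffles the order of the resulting batches:

    import random

    class ShuffledLengthBatchSampler:
        """Illustrative sketch: length-grouped batches whose composition
        and order change on every epoch."""

        def __init__(self, lengths, batch_size, chunk_factor=8):
            self.lengths = list(lengths)
            self.batch_size = batch_size
            # Sorting is confined to chunks of this many samples.
            self.chunk_size = batch_size * chunk_factor

        def __iter__(self):
            indices = list(range(len(self.lengths)))
            random.shuffle(indices)  # fresh random order each epoch
            batches = []
            for start in range(0, len(indices), self.chunk_size):
                chunk = indices[start:start + self.chunk_size]
                # Within each chunk, put samples of similar length next to each other.
                chunk.sort(key=lambda i: self.lengths[i], reverse=True)
                batches += [chunk[i:i + self.batch_size]
                            for i in range(0, len(chunk), self.batch_size)]
            random.shuffle(batches)  # also shuffle the order of the batches
            return iter(batches)

        def __len__(self):
            return (len(self.lengths) + self.batch_size - 1) // self.batch_size

    It can then be dropped into the same data loader in place of the precomputed batches list:

    >>> dl = DataLoader(dataset=ds,
    ...                 batch_sampler=ShuffledLengthBatchSampler(file_len, batch_size=6),
    ...                 collate_fn=collate_fn)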

    Do note that the batch_sampler option of the data loader is mutually exclusive with the batch_size, shuffle, sampler, and drop_last options!