I am converting someone else's code into a neater, PyTorch-style pipeline that uses datasets, dataloaders, collate functions and samplers. While I have done such work before, I am not sure how to tackle the following problem.
The dataset contains sentences as samples. Every sample therefore has a number of words (or tokens), which we can get by naively splitting the sample on whitespace (`sample.split()`). Such a dummy dataset can look like this:
from random import randint
from torch.utils.data import Dataset
class DummyDataset(Dataset):
    """Toy map-style dataset of 128 repetitive sentences.

    Every sample is the string ``"hello "`` repeated ``randint(64, 176)``
    times (both bounds inclusive), so ``sample.split()`` yields between 64
    and 176 tokens. The trailing space is harmless because ``str.split()``
    with no argument discards empty strings.
    """

    def __init__(self):
        # One randint draw per sample, in order, so seeding ``random`` makes
        # the dataset reproducible.
        self.data = ["hello " * randint(64, 176) for _ in range(128)]

    def __len__(self):
        """Return the number of sentences."""
        return len(self.data)

    def __getitem__(self, idx: int):
        """Return the sentence stored at position ``idx``."""
        return self.data[idx]
Now I want to be able to load data so that the maximum number of tokens in a batch is at most 250. That implies that the batch size (the number of samples) can differ between iterations: one batch may contain two samples with no more than 250 tokens in total (for instance 127 + 77), while another may contain three (66 + 66 + 66). The core functionality for this is rather straightforward; a full example is below. It is not optimized (e.g. by sorting the samples on length), but that is fine for this example.
The question is: how can I integrate this into the PyTorch ecosystem? Batch sizes are so often used to indicate the number of samples (as with the `batch_size` argument of `DataLoader`). So where should I plug this in, or what should I subclass, to make this work like a regular dataloader?
from random import randint
from torch.utils.data import Dataset
class DummyDataset(Dataset):
    """Toy map-style dataset of 128 repetitive sentences.

    Every sample is the string ``"hello "`` repeated ``randint(64, 176)``
    times (both bounds inclusive), so ``sample.split()`` yields between 64
    and 176 tokens. The trailing space is harmless because ``str.split()``
    with no argument discards empty strings.
    """

    def __init__(self):
        # One randint draw per sample, in order, so seeding ``random`` makes
        # the dataset reproducible.
        self.data = ["hello " * randint(64, 176) for _ in range(128)]

    def __len__(self):
        """Return the number of sentences."""
        return len(self.data)

    def __getitem__(self, idx: int):
        """Return the sentence stored at position ``idx``."""
        return self.data[idx]
def get_batch(dataset, max_tokens: int = 250):
    """Yield consecutive lists of samples whose total token count fits a budget.

    Samples are read in index order and packed greedily. A batch is emitted
    as soon as adding the next sample would push its token count (counted by
    whitespace splitting) above ``max_tokens``.

    Args:
        dataset: Indexable collection of strings supporting ``len`` and ``[]``.
        max_tokens: Maximum summed token count of one batch.

    Yields:
        Non-empty lists of samples, in dataset order.

    Raises:
        ValueError: If a single sample has more than ``max_tokens`` tokens.
    """
    batch = []
    total_batch_len = 0
    # Index-based access: map-style datasets only guarantee __len__/__getitem__.
    for idx in range(len(dataset)):
        sample = dataset[idx]
        sample_len = len(sample.split())
        if sample_len > max_tokens:
            # Such a sample can never be placed in any batch; fail loudly
            # instead of retrying it forever.
            raise ValueError(
                f"Sample {idx} has {sample_len} tokens, more than "
                f"max_tokens={max_tokens}"
            )
        if total_batch_len + sample_len > max_tokens:
            # Reachable only with a non-empty batch, because
            # sample_len <= max_tokens here.
            yield batch
            batch = []
            total_batch_len = 0
        batch.append(sample)
        total_batch_len += sample_len
    if batch:
        # Never emit an empty trailing batch (e.g. for an empty dataset).
        yield batch


if __name__ == '__main__':
    dataset = DummyDataset()

    # Sanity check that we indeed get all items from the dataset
    num_samples = 0
    num_batches = 0
    for b in get_batch(dataset):
        num_samples += len(b)
        num_batches += 1
    print(f"Created {num_batches} batches")
    assert num_samples == len(dataset)
Maybe torchtext's `Iterator` and its `batch_size_fn` argument can help, but I have no experience with it (where should I add it? Is it a dataloader itself, or should I still wrap a dataloader around it?).
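After reading some of the source code, it seems that you can pass any iterable that yields lists of indices as a DataLoader's `batch_sampler`. It must be re-iterable (not a one-shot iterator or generator) if you loop over the DataLoader for several epochs, because the DataLoader calls `iter()` on it each time. It also needs `__len__` for `len(dataloader)` to work. So the following works as expected.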
from random import randint
from torch.utils.data import Dataset
from torch.utils.data.dataloader import DataLoader
class DummyDataset(Dataset):
    """Toy map-style dataset of 128 repetitive sentences.

    Every sample is the string ``"hello "`` repeated ``randint(64, 176)``
    times (both bounds inclusive), so ``sample.split()`` yields between 64
    and 176 tokens. The trailing space is harmless because ``str.split()``
    with no argument discards empty strings.
    """

    def __init__(self):
        # One randint draw per sample, in order, so seeding ``random`` makes
        # the dataset reproducible.
        self.data = ["hello " * randint(64, 176) for _ in range(128)]

    def __len__(self):
        """Return the number of sentences."""
        return len(self.data)

    def __getitem__(self, idx: int):
        """Return the sentence stored at position ``idx``."""
        return self.data[idx]
class TokenBatchSampler:
    """Batch sampler that groups dataset indices by a token budget.

    Indices are taken in dataset order and packed greedily. A batch is closed
    as soon as adding the next sample would push its total token count
    (counted by whitespace splitting) above ``max_tokens``. The resulting lists
    of indices are meant for ``DataLoader(..., batch_sampler=...)``, so the
    number of samples per batch varies.

    Args:
        max_tokens: Upper bound on the summed token count of one batch.
        dataset: Indexable collection of strings (``len`` and ``[]``). If
            omitted, a module-level ``dataset`` global is used for backward
            compatibility with code that relied on it.

    Raises:
        ValueError: If a single sample alone has more than ``max_tokens``
            tokens.
        TypeError: If no dataset is given and no module-level ``dataset``
            exists.
    """

    def __init__(self, max_tokens: int = 250, dataset=None):
        if dataset is None:
            # Backward compatibility: earlier callers did not pass a dataset
            # and the batches were built from a module-level global.
            dataset = globals().get("dataset")
            if dataset is None:
                raise TypeError(
                    "TokenBatchSampler needs a dataset; pass dataset=..."
                )
        self.dataset = dataset
        self.max_tokens = max_tokens
        self.batches = []
        self._prepare_dataset()

    def __len__(self) -> int:
        """Return the number of batches (used by ``len(dataloader)``)."""
        return len(self.batches)

    def __iter__(self):
        """Return a fresh iterator over the precomputed index lists.

        A new iterator per call lets the DataLoader iterate again every epoch.
        """
        return iter(self.batches)

    def _prepare_dataset(self):
        """Precompute ``self.batches`` as greedy, in-order lists of indices."""
        batches = []
        batch_idxs = []
        total_batch_len = 0
        for sample_idx in range(len(self.dataset)):
            sample_len = len(self.dataset[sample_idx].split())
            if sample_len > self.max_tokens:
                # Such a sample can never be placed in any batch; fail loudly
                # instead of retrying it forever.
                raise ValueError(
                    f"Sample {sample_idx} has {sample_len} tokens, more than "
                    f"max_tokens={self.max_tokens}"
                )
            if total_batch_len + sample_len > self.max_tokens:
                # Reachable only with a non-empty batch, because
                # sample_len <= max_tokens here.
                batches.append(batch_idxs)
                batch_idxs = []
                total_batch_len = 0
            batch_idxs.append(sample_idx)
            total_batch_len += sample_len
        if batch_idxs:
            # No empty trailing batch: default_collate cannot handle [].
            batches.append(batch_idxs)
        self.batches = batches
if __name__ == "__main__":
    dataset = DummyDataset()
    sampler = TokenBatchSampler()
    dataloader = DataLoader(dataset, batch_sampler=sampler)

    # Over several epochs, confirm that every item of the dataset is
    # delivered exactly once per pass through the DataLoader.
    for epoch in range(3):
        batch_sizes = [len(b) for b in dataloader]
        print(f"Created {len(batch_sizes)} batches in epoch {epoch}")
        assert sum(batch_sizes) == len(dataset)

    print(f"DataLoader length {len(dataloader)}")