I have a torch.utils.data.Dataset
object, and I would like a DataLoader
(or a similar object) that accepts a list of indices and returns a batch of the corresponding samples.
For example, given
list_idxs = [10, 109, 7, 12]
I would like to do something like:
batch = loader.getbatch(list_idxs)
where batch contains:
[sample10, sample109, sample7, sample12]
Is there a simple, elegant, and optimized way to do that?
If I understand your question correctly, you could have a DataLoader
return a sequence of hand-selected batches using a custom batch_sampler
(you don't even need to pass it a sampler
in this case).
Given an arbitrary Dataset:
>>> from torch.utils.data import DataLoader, Dataset
>>> from torch.utils.data.sampler import Sampler
>>> class MyDataset(Dataset):
...     def __getitem__(self, idx):
...         return idx
you can then define something like:
>>> class MyBatchSampler(Sampler):
...     def __init__(self, batches):
...         self.batches = batches
...
...     def __iter__(self):
...         for batch in self.batches:
...             yield batch
...
...     def __len__(self):
...         return len(self.batches)
which just takes a list of lists of dataset indices, one inner list per batch. Then:
>>> dataset = MyDataset()
>>> batch_sampler = MyBatchSampler([[1, 2, 3], [5, 6, 7], [4, 2, 1]])
>>> dataloader = DataLoader(dataset=dataset, batch_sampler=batch_sampler)
>>> for batch in dataloader:
...     print(batch)
...
tensor([1, 2, 3])
tensor([5, 6, 7])
tensor([4, 2, 1])
This should be easy to extend to your actual Dataset.
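If you only want one hand-picked batch rather than a full DataLoader loop, you can also skip the sampler machinery entirely: index the dataset yourself and collate the samples with torch's default collate function, which is what DataLoader uses internally. A minimal sketch (the helper name get_batch and the dataset returning its own index are my own illustrative choices, not PyTorch API):

```python
from torch.utils.data import Dataset
from torch.utils.data.dataloader import default_collate


class MyDataset(Dataset):
    # Toy dataset: each sample is just its own index.
    def __getitem__(self, idx):
        return idx

    def __len__(self):
        return 200


def get_batch(dataset, idxs):
    # Fetch the requested samples and stack them into one batch,
    # mirroring what DataLoader does with its default collate_fn.
    return default_collate([dataset[i] for i in idxs])


batch = get_batch(MyDataset(), [10, 109, 7, 12])
print(batch)  # a tensor holding the four requested samples, in order
```

This preserves the requested ordering and works for any dataset whose samples the default collate function can stack (tensors, numbers, tuples/dicts of these).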