python pytorch dataset dataloader

How can I get a batch of samples from a dataset given a list of idxs in pytorch?


I have a torch.utils.data.Dataset object. I would like a DataLoader or a similar object that accepts a list of indices and returns a batch containing the corresponding samples.

For example, given

list_idxs = [10, 109, 7, 12]

I would like to do something like:

batch = loader.getbatch(list_idxs)

where batch contains:

[sample10, sample109, sample7, sample12]

Is there a simple, elegant, and efficient way to do that?


Solution

  • If I understand your question correctly, you can have a DataLoader return a sequence of hand-selected batches by passing it a custom batch_sampler. You don't need to pass a sampler in this case; in fact, batch_sampler is mutually exclusive with sampler, batch_size, shuffle, and drop_last.

    Given an arbitrary Dataset:

    >>> from torch.utils.data import DataLoader, Dataset
    >>> from torch.utils.data.sampler import Sampler
    >>> class MyDataset(Dataset):
    ...     def __getitem__(self, idx):
    ...         return idx
    

    you can then define something like:

    >>> class MyBatchSampler(Sampler):
    ...     def __init__(self, batches):
    ...         self.batches = batches
    ...
    ...     def __iter__(self):
    ...         for batch in self.batches:
    ...             yield batch
    ...
    ...     def __len__(self):
    ...         return len(self.batches)
    

    which just takes a list of lists, where each inner list holds the dataset indices to include in one batch.

    Then:

    >>> dataset = MyDataset()
    >>> batch_sampler = MyBatchSampler([[1, 2, 3], [5, 6, 7], [4, 2, 1]])
    >>> dataloader = DataLoader(dataset=dataset, batch_sampler=batch_sampler)
    >>> for batch in dataloader:
    ...     print(batch)
    ... 
    tensor([1, 2, 3])
    tensor([5, 6, 7])
    tensor([4, 2, 1])
    

    This should be easy to extend to your actual Dataset.
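
If you only need a single batch on demand, closer to the loader.getbatch(list_idxs) call in the question, a minimal sketch is to index the dataset directly and collate the samples yourself. The names SquaresDataset and get_batch below are hypothetical, introduced only for illustration; default_collate is the function DataLoader uses by default to stack samples into a batch.

```python
import torch
from torch.utils.data import Dataset
from torch.utils.data.dataloader import default_collate


class SquaresDataset(Dataset):
    """Hypothetical toy dataset: sample i is the pair (i, i**2)."""

    def __len__(self):
        return 1000

    def __getitem__(self, idx):
        return torch.tensor(idx), torch.tensor(idx ** 2)


def get_batch(dataset, idxs, collate_fn=default_collate):
    """Fetch the samples at `idxs`, in order, and collate them into one batch.

    Hypothetical helper, not part of PyTorch. It runs in the calling
    process, so it does not use DataLoader worker processes.
    """
    return collate_fn([dataset[i] for i in idxs])


dataset = SquaresDataset()
list_idxs = [10, 109, 7, 12]

# Each sample is a tuple, so default_collate returns one stacked tensor
# per tuple position.
xs, ys = get_batch(dataset, list_idxs)
print(xs.tolist(), ys.tolist())
```

This is probably fine for occasional lookups. For many batches, the batch_sampler approach above is likely the better fit, since it lets the DataLoader handle worker processes and memory pinning.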