How to make `__getitems__` return a dict?

In torch's Dataset, on top of the obligatory __getitem__ method, you can implement the __getitems__ method.

In my case __getitem__ returns a dict, but I can't figure out how to do the same with __getitems__.

import numpy as np
from torch.utils.data import DataLoader, Dataset


class StackOverflowDataset(Dataset):
    def __init__(self, data):
        self._data = data

    def __getitem__(self, idx):
        return {'item': self._data[idx], 'whatever': idx*self._data[idx]+3}

    def __getitems__(self, idxs):
        # Batched variant: receives the list of indices for one batch
        return {'item': self._data[idxs], 'whatever': idxs*self._data[idxs]+3}

    def __len__(self):
        return len(self._data)

dataset = StackOverflowDataset(np.random.random(5))
for X in DataLoader(dataset, 2):
    print(X)
    break

If I comment out __getitems__ it works, but leaving it there raises a KeyError: 0.

KeyError                                  Traceback (most recent call last)
Cell In[182], line 15
     12         return len(self._data)
     14 dataset = StackOverflowDataset(np.random.random(5))
---> 15 for X in DataLoader(dataset, 2):
     16     print(X)
     17     break

File ~/recommenders/venv/lib/python3.12/site-packages/torch/utils/data/, in _BaseDataLoaderIter.__next__(self)
    627 if self._sampler_iter is None:
    628     # TODO(
    629     self._reset()  # type: ignore[call-arg]
--> 630 data = self._next_data()
    631 self._num_yielded += 1
    632 if self._dataset_kind == _DatasetKind.Iterable and \
    633         self._IterableDataset_len_called is not None and \
    634         self._num_yielded > self._IterableDataset_len_called:

File ~/recommenders/venv/lib/python3.12/site-packages/torch/utils/data/, in _SingleProcessDataLoaderIter._next_data(self)
    671 def _next_data(self):
    672     index = self._next_index()  # may raise StopIteration
--> 673     data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
    674     if self._pin_memory:
    675         data = _utils.pin_memory.pin_memory(data, self._pin_memory_device)

File ~/recommenders/venv/lib/python3.12/site-packages/torch/utils/data/_utils/, in _MapDatasetFetcher.fetch(self, possibly_batched_index)
     53 else:
     54     data = self.dataset[possibly_batched_index]
---> 55 return self.collate_fn(data)

File ~/recommenders/venv/lib/python3.12/site-packages/torch/utils/data/_utils/, in default_collate(batch)
    256 def default_collate(batch):
    257     r"""
    258     Take in a batch of data and put the elements within the batch into a tensor with an additional outer dimension - batch size.
    315         >>> default_collate(batch)  # Handle `CustomType` automatically
    316     """
--> 317     return collate(batch, collate_fn_map=default_collate_fn_map)

File ~/recommenders/venv/lib/python3.12/site-packages/torch/utils/data/_utils/, in collate(batch, collate_fn_map)
    109 def collate(batch, *, collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None):
    110     r"""
    111     General collate function that handles collection type of element within each batch.
    135         for the dictionary of collate functions as `collate_fn_map`.
    136     """
--> 137     elem = batch[0]
    138     elem_type = type(elem)
    140     if collate_fn_map is not None:

KeyError: 0


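  • That's because PyTorch's default collate function indexes the batch by position, starting with batch[0] (see the last frame of the traceback). On a dict, that is a lookup of the key 0, hence KeyError: 0. The official documentation says:

    Subclasses could also optionally implement __getitems__(), for speedup batched samples loading. This method accepts list of indices of samples of batch and returns list of samples.

    In other words, __getitems__ should return a list with one sample per index (here, a list of dicts), not a single dict.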
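
As a minimal sketch of the fix described above, the batched method can build a list of per-sample dicts by reusing __getitem__. The class name ListOfDictsDataset and the fixed np.arange input are illustrative assumptions, not part of the original question. The default collate function should then merge the list into a single dict of batched tensors.

```python
import numpy as np
from torch.utils.data import DataLoader, Dataset


class ListOfDictsDataset(Dataset):  # hypothetical name, for illustration only
    def __init__(self, data):
        self._data = data

    def __getitem__(self, idx):
        # One sample, returned as a dict
        return {'item': self._data[idx], 'whatever': idx * self._data[idx] + 3}

    def __getitems__(self, idxs):
        # One dict per requested index, collected in a list so that
        # the collate function can access batch[0], batch[1], ...
        return [self[i] for i in idxs]

    def __len__(self):
        return len(self._data)


# Deterministic input (an assumption; the question used np.random.random(5))
dataset = ListOfDictsDataset(np.arange(5, dtype=np.float64))
batch = next(iter(DataLoader(dataset, batch_size=2)))
print(batch)  # a dict with keys 'item' and 'whatever', each a tensor of length 2
```

This version gives up the vectorized lookup from the question. If that matters, another option that should work is to keep returning a dict of arrays from __getitems__ and pass DataLoader a custom collate_fn that accepts that dict, since the traceback shows the fetcher handing the fetched data straight to collate_fn.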