Tags: python, deep-learning, pytorch

Do I have to pass the dataset both to the loader and the RandomSampler?


This is my code:

    def train_dataloader(self):
        if self._is_weighted_sampler:
            weights = list(self.label_weight_by_name.values())
            sampler = torch.utils.data.sampler.WeightedRandomSampler(
                torch.tensor(weights), len(weights)
            )
        else:
            sampler = torch.utils.data.RandomSampler(self._train_dataset)
        return DataLoader(self._train_dataset, batch_size=self._batch_size, shuffle=True, sampler=sampler)

Notice that in the weighted-sampler case the sampler doesn't require the dataset, but RandomSampler does.

In the RandomSampler case, this means the dataset is passed twice: once to the sampler and once to the DataLoader.

I must be missing something about how this is meant to be used; please correct me.


Solution

  • Actually, it looks like you're correct about this discrepancy; there doesn't seem to be an immediately obvious reason why one call needs the dataset object and the other does not. Per the docs, the function prototypes are:

    torch.utils.data.RandomSampler(data_source, replacement=False, num_samples=None, generator=None)
    

    Digging into the source code, one sees that data_source is never indexed by RandomSampler and is only ever used as len(data_source). The sampler yields indices; the dataset object itself is only used to determine the length of the data.
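
    As a quick sketch of that behaviour: any object with a length can stand in for the dataset, because the sampler never touches the items themselves (the range object below is just a stand-in for a five-item dataset):

    import torch

    sampler = torch.utils.data.RandomSampler(range(5))  # data_source is only used for len()
    print(list(sampler))  # e.g. [3, 0, 4, 1, 2] -- a permutation of the indices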

    torch.utils.data.WeightedRandomSampler(weights, num_samples, replacement=True, generator=None)
    

    This draws indices at random according to weights (which should be set up to have the same number of elements as your dataset) and returns those indices, which must then be used separately to index into the dataset object.
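
    For instance, a sketch with three made-up weights (one per hypothetical sample) shows that what comes out is just indices, drawn in proportion to the weights:

    import torch

    weights = torch.tensor([0.1, 0.1, 0.8])  # one weight per sample in the dataset
    sampler = torch.utils.data.WeightedRandomSampler(weights, num_samples=6)  # replacement=True by default
    print(list(sampler))  # e.g. [2, 2, 0, 2, 1, 2] -- index 2 is drawn most often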

    Perhaps the rationale of the developers was: "in the case of a WeightedRandomSampler, the user must define a weight for each item in the data source. In the case of RandomSampler, all of the weights are one, so rather than defining a vector of ones, the user can simply pass the dataset object itself." Why they wouldn't just let the user pass the length of the dataset as an integer to RandomSampler is beyond me.
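
    For completeness, here is an untested sketch of how the method from the question might end up looking (it assumes the same attribute names as the question). Note that DataLoader treats shuffle and sampler as mutually exclusive, so shuffle should be dropped once an explicit sampler is supplied; and yes, in the RandomSampler branch the dataset really is passed twice, once for its length and once for the actual items:

    import torch
    from torch.utils.data import DataLoader, RandomSampler, WeightedRandomSampler

    def train_dataloader(self):
        if self._is_weighted_sampler:
            # per the docs above, weights should have one entry per sample in the dataset
            weights = torch.tensor(list(self.label_weight_by_name.values()))
            sampler = WeightedRandomSampler(weights, num_samples=len(weights))
        else:
            # the sampler only calls len() on the dataset and yields indices
            sampler = RandomSampler(self._train_dataset)
        # shuffle is omitted: it is mutually exclusive with a custom sampler
        return DataLoader(self._train_dataset, batch_size=self._batch_size, sampler=sampler)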