This is my code:

```python
def train_dataloader(self):
    if self._is_weighted_sampler:
        weights = list(self.label_weight_by_name.values())
        sampler = torch.utils.data.sampler.WeightedRandomSampler(
            torch.tensor(weights), len(weights)
        )
    else:
        sampler = torch.utils.data.RandomSampler(self._train_dataset)
    return DataLoader(
        self._train_dataset, batch_size=self._batch_size, shuffle=True, sampler=sampler
    )
```
Notice that in the weighted case the sampler doesn't require the dataset, but `RandomSampler` does. In the `RandomSampler` case, this means the dataset is passed twice: once to the sampler and once to the `DataLoader`. I must be missing something about how this is meant to be used; please correct me.
Actually, it looks like you're correct about this discrepancy; there doesn't seem to be an immediately obvious reason why one call needs the dataset object and the other does not. Per the docs the function prototypes are:
```python
torch.utils.data.RandomSampler(data_source, replacement=False, num_samples=None, generator=None)
```

Digging into the source code, one sees that `data_source` is never indexed by `RandomSampler` and is only ever used as `len(data_source)`. This object yields indices, and the dataset object is only used to determine the length of the data.
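To make that concrete, here is a minimal pure-Python sketch (not the actual torch source) of what `RandomSampler` with `replacement=False` effectively does: yield a random permutation of `range(len(data_source))`, never touching the elements themselves.

```python
import random

class SketchRandomSampler:
    """Pure-Python sketch of torch.utils.data.RandomSampler (replacement=False).

    Only len(data_source) is ever used; the elements are never accessed.
    """
    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        indices = list(range(len(self.data_source)))
        random.shuffle(indices)  # yield a random permutation of the indices
        return iter(indices)

    def __len__(self):
        return len(self.data_source)

class LengthOnly:
    """Has a length but is not indexable; indexing it would raise TypeError."""
    def __len__(self):
        return 5

sampler = SketchRandomSampler(LengthOnly())
print(sorted(sampler))  # → [0, 1, 2, 3, 4]
```

Note that `LengthOnly` cannot be indexed at all, yet the sampler works fine, which is exactly the point: only the length matters.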
```python
torch.utils.data.WeightedRandomSampler(weights, num_samples, replacement=True, generator=None)
```

This randomly samples indices according to `weights` (which should be set up to have the same number of elements as your dataset) and returns a set of indices, which then must be separately used to index into the dataset object.
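Conceptually (again a sketch, not the torch source), `WeightedRandomSampler` with `replacement=True` just draws `num_samples` indices with probability proportional to the weights, which in plain Python looks like `random.choices`:

```python
import random

def sketch_weighted_random_sampler(weights, num_samples):
    """Sketch of torch.utils.data.WeightedRandomSampler (replacement=True):
    yields num_samples indices into range(len(weights)), where index i is
    drawn with probability weights[i] / sum(weights)."""
    return iter(random.choices(range(len(weights)), weights=weights, k=num_samples))

# Index 2 has weight 0.0, so it can never be drawn.
indices = list(sketch_weighted_random_sampler([1.0, 3.0, 0.0], num_samples=10))
```

Here the "dataset" never appears at all; the sampler only knows about the weight vector, whose length is expected to match the dataset's.
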
Perhaps the rationale of the developers was: "in the case of a `WeightedRandomSampler`, the user must define a weight for each item in the data source. In the case of `RandomSampler`, all of the weights are one, so rather than defining a vector of ones the user can simply pass the dataset object itself." Why they wouldn't just pass the length of the dataset as an integer to `RandomSampler` is beyond me.
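Two practical notes on the code in the question, illustrated with a hedged sketch (the toy dataset and per-class weight dict below are hypothetical stand-ins for your `_train_dataset` and `label_weight_by_name`): first, `DataLoader` raises a `ValueError` if you pass `shuffle=True` together with a `sampler`, so `shuffle` must be dropped; second, `WeightedRandomSampler` expects one weight *per sample* (with `num_samples` typically `len(dataset)`), not one weight per class, so per-class weights need to be expanded over the sample labels.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

# Toy dataset: 6 samples with 3 features each, labels 0/1.
labels = torch.tensor([0, 0, 0, 0, 1, 1])
dataset = TensorDataset(torch.randn(6, 3), labels)

# Hypothetical per-class weights (e.g. inverse class frequency),
# expanded into one weight per sample.
class_weights = {0: 1.0, 1: 2.0}
sample_weights = torch.tensor([class_weights[int(y)] for y in labels])

sampler = WeightedRandomSampler(sample_weights, num_samples=len(dataset))

# No shuffle=True here: DataLoader rejects shuffle combined with a sampler.
loader = DataLoader(dataset, batch_size=2, sampler=sampler)
batches = list(loader)  # 6 samples / batch_size 2 → 3 batches
```

With a custom sampler the shuffling is the sampler's job, which is why `DataLoader` treats the two options as mutually exclusive.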