Tags: python, pytorch, pytorch-dataloader

Using PyTorch DataLoader with a Probability Distribution


TL;DR: I want to use DataLoader to take a weighted random sample of the available rows. How do I do that?


I've put together some Python code that fits a certain kind of input-driven dynamical system to data using batched gradient descent over the parameters that define the model. The following PyTorch snippet gets the job done:

k_trn = self.linear.k_gen(in_trn, t)                 # N x K tensor of generated features
u_trn = torch.tensor(in_trn.T)                       # N x 1 tensor of inputs
x_trn = torch.tensor(out_trn.T, dtype=torch.float)   # N x n tensor of outputs/states
# pair each input/feature row with the state one step later
data = TensorDataset(u_trn[:-1, :], k_trn[:-1, :], x_trn[1:, :])
loader = DataLoader(data, batch_size=20, shuffle=True)

Data types:

  • u_trn: N x 1 tensor (PyTorch's array type)
  • k_trn: N x K tensor
  • x_trn: N x n tensor

The rows of u_trn, k_trn, and x_trn correspond to three trajectories (with u corresponding to the "input"). Each time I iterate over the loader (e.g. with a loop for u, k, x in loader:), I get a batch of 20 rows from u_trn, 20 rows from k_trn, and 20 rows from x_trn. These rows are selected with uniform probability, without replacement.

The catch is that I would like to sample these rows with a non-uniform probability. In particular, let S = 1/1 + 1/2 + ... + 1/N; I would like the loader to select the jth row with probability 1/(S*j).
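
For concreteness, here is a tiny snippet (separate from my model code, with N chosen arbitrarily) that spells out the target distribution and checks that it sums to 1:

import numpy as np

N = 100                                      # number of rows (arbitrary example value)
S = sum(1.0 / j for j in range(1, N + 1))    # harmonic normalizing constant
probs = np.array([1.0 / (S * j) for j in range(1, N + 1)])

print(probs[:5])    # probabilities of the first five rows
print(probs.sum())  # ~1.0, up to floating-point error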

After looking at the relevant documentation, I suspect that this can be done by messing with either the sampler or batch_sampler keyword arguments when initializing the DataLoader object, but I'm having trouble parsing the documentation well enough to implement the behavior that I'm looking for.
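
For what it's worth, here is the rough, untested shape of what I had in mind for a custom sampler passed through the sampler argument (the class name and details are my own guesses and may well be off):

import torch
from torch.utils.data import DataLoader, Sampler

class InverseIndexSampler(Sampler):
    """Yields dataset indices, drawing index j (0-based) with weight 1/(j+1)."""
    def __init__(self, data_source, num_samples):
        self.weights = torch.tensor(
            [1.0 / (j + 1) for j in range(len(data_source))], dtype=torch.double
        )
        self.num_samples = num_samples

    def __iter__(self):
        # torch.multinomial normalizes the weights internally
        idx = torch.multinomial(self.weights, self.num_samples, replacement=True)
        return iter(idx.tolist())

    def __len__(self):
        return self.num_samples

# a sampler is mutually exclusive with shuffle=True, so shuffle is left at its default
loader = DataLoader(data, batch_size=20, sampler=InverseIndexSampler(data, len(data)))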

I'd appreciate any help with this. I've tried to keep my question brief; please let me know if I've left out any relevant information.


Followup: with the help of Shai's answer, I've gotten things to work properly. Here's a quick script that I used to test this out and make sure that everything was working as expected.

import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler
import matplotlib.pyplot as plt

N = 100
x = np.zeros((N, 2))
x[:, 0] = 1 + np.arange(N)              # row j (0-based) stores the value j+1
data = TensorDataset(torch.Tensor(x))

weights = [1/j for j in range(1, N+1)]  # my weights (unnormalized)
sampler = WeightedRandomSampler(weights, 10000, replacement=True)
loader = DataLoader(data, batch_size=20, sampler=sampler)
sums = []

for y, in loader:
    for k in range(len(y)):
        sums.append(np.sum(y[k].numpy()))    # the row sum recovers which row was drawn

h = plt.hist(sums, bins=N)                   # empirical counts per row
a = h[0][0]                                  # count of the most likely row (j = 1)
plt.plot([a/(n+1) for n in range(N)], lw=3)  # expected 1/j decay, scaled to match

And the resulting plot:

[Resulting plot: histogram of the sampled rows overlaid with the expected 1/j curve.]

Note that the weights are normalized automatically, so there is no need to divide by the sum S. Note also that shuffle=True is unnecessary in the loader (in fact, DataLoader rejects shuffle=True when a sampler is given); the sampler takes care of the randomization on its own.
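
As a quick sanity check of the normalization point (not part of the fitting code), these two samplers draw indices from the same distribution, since WeightedRandomSampler only cares about the relative sizes of the weights:

from torch.utils.data import WeightedRandomSampler

N = 100
S = sum(1/j for j in range(1, N+1))

raw = [1/j for j in range(1, N+1)]    # unnormalized weights
scaled = [w / S for w in raw]         # explicitly normalized weights

# Equivalent samplers: only relative weights matter.
s_raw = WeightedRandomSampler(raw, 10000, replacement=True)
s_scaled = WeightedRandomSampler(scaled, 10000, replacement=True)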


Solution

  • Why don't you simply use WeightedRandomSampler?

    weights = [1./(S*j) for j in range(1, N+1)]  # your weights (dividing by S is optional)
    sampler = WeightedRandomSampler(weights, num_samples=len(weights), replacement=True)
    loader = DataLoader(data, batch_size=20, sampler=sampler)
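
    Note that num_samples is a required argument of WeightedRandomSampler: it sets how many indices the sampler yields per pass over the loader (len(weights), i.e. roughly one draw per row, is one reasonable choice; the follow-up script above uses 10000). Iteration then works exactly as before; a minimal usage sketch, assuming the question's three-tensor dataset:

    for u, k, x in loader:
        # each batch of 20 rows is drawn according to the weights, with replacement
        ...  # gradient step on this weighted batch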