# Using PyTorch DataLoader with Probability Distribution

TL;DR: I want to use `DataLoader` to take a **weighted** random sample of the available rows. How do I do that?

I've put together some Python code that fits a certain kind of input-driven dynamical system to data using batched gradient descent over the parameters that define the model. The following snippet gets the job done using PyTorch.

```
k_trn = self.linear.k_gen(in_trn,t)
u_trn = torch.tensor(in_trn.T)
x_trn = torch.tensor(out_trn.T, dtype = torch.float)
data = TensorDataset(u_trn[:-1,:],k_trn[:-1,:],x_trn[1:,:])
loader = DataLoader(data, batch_size = 20, shuffle = True)
```

Data types:

- `u_trn`: N x 1 tensor (PyTorch's array type)
- `k_trn`: N x K tensor
- `x_trn`: N x n tensor

The rows of `u_trn`, `k_trn`, and `x_trn` correspond to three trajectories (with u corresponding to the "input"). Each time I iterate over the loader (e.g. with a loop `for u,k,x in loader:`), I get a batch of 20 rows from `u_trn`, 20 rows from `k_trn`, and 20 rows from `x_trn`. These rows are selected with uniform probability, without replacement.

The catch is that I would like to sample these rows with a **non-uniform** probability. In particular: denote S = (1/1 + 1/2 + ... + 1/N). I would like for the loader to select the jth row with probability 1/(S*j).
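As a quick sanity check on the target distribution, the probabilities 1/(S*j) can be computed directly. This is a minimal sketch in plain Python; `harmonic_probabilities` is a hypothetical helper name used only for illustration, not part of PyTorch.

```python
def harmonic_probabilities(N):
    """Return [p_1, ..., p_N] with p_j = 1 / (S * j), where S = 1/1 + ... + 1/N."""
    S = sum(1.0 / j for j in range(1, N + 1))
    return [1.0 / (S * j) for j in range(1, N + 1)]


probs = harmonic_probabilities(100)
print(sum(probs))           # close to 1, up to floating-point error
print(probs[0] / probs[1])  # row 1 is about twice as likely as row 2
```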

After looking at the relevant documentation, I suspect that this can be done by using either the `sampler` or `batch_sampler` keyword argument when initializing the `DataLoader` object, but I'm having trouble parsing the documentation well enough to implement the behavior that I'm looking for.

I'd appreciate any help with this. I've tried to keep my question brief; please let me know if I've left out any relevant information.

Followup: with the help of Shai's answer, I've gotten things to work properly. Here's a quick script that I used to test this out and make sure that everything was working as expected.

```
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler
import matplotlib.pyplot as plt

N = 100
x = np.zeros((N, 2))
x[:, 0] = 1 + np.arange(N)  # row j holds the value j in its first column
data = TensorDataset(torch.Tensor(x))
weights = [1/j for j in range(1, N+1)]  # my weights
sampler = WeightedRandomSampler(weights, 10000, replacement=True)
loader = DataLoader(data, batch_size=20, sampler=sampler)

sums = []
for y, in loader:
    for k in range(len(y)):
        sums.append(np.sum(y[k].numpy()))  # row sum recovers j

h = plt.hist(sums, bins=N)
a = h[0][0]  # count in the first bin (j = 1)
plt.plot([a/(n+1) for n in range(N)], lw=3)  # reference curve proportional to 1/j
```

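And the resulting plot (image not available): a histogram of the sampled values with the 1/j reference curve drawn over it.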

Note that weights are automatically normalized, so there is no need to divide by the sum S. Note also that there is no need for `shuffle=True` in the loader; the sampler takes care of the randomization on its own.
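The normalization claim can also be checked numerically rather than by eye. The sketch below draws many indices from a `WeightedRandomSampler` given the unnormalized weights 1/j and compares the observed frequencies with 1/(S*j). The small `N`, the number of draws, and the seed are arbitrary choices made for illustration.

```python
import torch
from torch.utils.data import WeightedRandomSampler

N = 10
num_draws = 100_000
weights = [1.0 / j for j in range(1, N + 1)]  # unnormalized: sums to S, not 1

gen = torch.Generator().manual_seed(0)  # fixed seed, only for repeatability
sampler = WeightedRandomSampler(weights, num_draws, replacement=True, generator=gen)

# The sampler yields dataset indices 0..N-1; index i corresponds to row j = i + 1.
indices = torch.tensor(list(sampler))
observed = torch.bincount(indices, minlength=N).double() / num_draws

S = sum(weights)
expected = torch.tensor([w / S for w in weights], dtype=torch.double)

max_abs_error = (observed - expected).abs().max().item()
print(max_abs_error)  # should be small, and shrinks as num_draws grows
```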

## Solution

Why don't you simply use `WeightedRandomSampler`?

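```
from torch.utils.data import DataLoader, WeightedRandomSampler
S = sum(1./j for j in range(1, N+1))  # normalizing constant from the question
weights = [1./(S*j) for j in range(1, N+1)] # your weights
# num_samples is a required argument; N draws per epoch is one possible choice
sampler = WeightedRandomSampler(weights, num_samples=N, replacement=True)
loader = DataLoader(data, batch_size=20, sampler=sampler)
```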
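Applied to the three-tensor dataset from the question, this might look like the following. It is a sketch under assumptions: random tensors stand in for `u_trn`, `k_trn`, and `x_trn` (the real `k_trn` comes from `self.linear.k_gen`, which is not shown in the question), and the sizes `N`, `K`, `n` are made up.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

N, K, n = 200, 3, 2          # made-up sizes, for illustration only
u_trn = torch.randn(N, 1)    # placeholder for the input trajectory
k_trn = torch.randn(N, K)    # placeholder for the output of k_gen
x_trn = torch.randn(N, n)    # placeholder for the state trajectory

# Same pairing as in the question: inputs at step t, state at step t + 1.
data = TensorDataset(u_trn[:-1, :], k_trn[:-1, :], x_trn[1:, :])

# One weight per dataset row; the dataset has N - 1 rows after the slicing above.
weights = [1.0 / j for j in range(1, len(data) + 1)]
sampler = WeightedRandomSampler(weights, num_samples=len(data), replacement=True)

# shuffle is left at its default because the sampler decides the order.
loader = DataLoader(data, batch_size=20, sampler=sampler)

for u, k, x in loader:
    print(u.shape, k.shape, x.shape)
    break
```

Because `replacement=True`, the same row can appear more than once within an epoch, and low-weight rows may not appear at all. Passing both `sampler` and `shuffle=True` to `DataLoader` raises an error, so only one of the two should be set.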
