# PyTorch: How to sample from a tensor where each value in the tensor has a different likelihood of being selected?

Given the tensor

`A = torch.tensor([0.0316, 0.2338, 0.2338, 0.2338, 0.0316, 0.0316, 0.0860, 0.0316, 0.0860])`

containing probabilities which sum to 1 (I removed some decimals, but it's safe to assume it will always sum to 1), I want to sample a value from `A` where the value itself is the likelihood of getting sampled. For instance, the likelihood of sampling `0.0316` from `A` is `0.0316`. The sampled value should still be a tensor.

I tried using `WeightedRandomSampler`, but it doesn't allow the selected value to remain a tensor; instead it detaches.

One caveat that makes this tricky is that I also want to know the index of the sampled value as it appears in the tensor. That is, if I sample `0.2338`, I want to know whether it is index `1`, `2` or `3` of tensor `A`.

Solution

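Selecting with the expected probabilities can be achieved by accumulating the weights into cumulative sums and then finding the insertion index of a uniform random float from [0, 1) in that sorted sequence. Each index is hit with probability equal to the width of its interval, which is its weight. The example tensor `A` is slightly adjusted so that it sums to exactly 1.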

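```
import torch

# Last entry changed from 0.0860 to 0.0862 so the probabilities sum to 1
A = torch.tensor([0.0316, 0.2338, 0.2338, 0.2338, 0.0316, 0.0316, 0.0860, 0.0316, 0.0862], requires_grad=True)
# Cumulative sums of the probabilities
p = A.cumsum(0)
# tensor([0.0316, 0.2654, 0.4992, 0.7330, 0.7646, 0.7962, 0.8822, 0.9138, 1.0000], grad_fn=<CumsumBackward0>)
# Index of the first cumulative sum that is >= a uniform random float in [0, 1)
idx = torch.searchsorted(p, torch.rand(1))
# Indexing A keeps the sampled value attached to the autograd graph
A[idx], idx
```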

Output

```
(tensor([0.2338], grad_fn=<IndexBackward0>), tensor([3]))
```
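The same idea can be wrapped in a small helper that draws several samples at once. This is only a sketch: the function name `sample_with_index` and its parameters are made up for illustration. It also clamps the index, on the assumption that floating-point rounding could leave the last cumulative sum slightly below the random draw.

```python
import torch

def sample_with_index(probs: torch.Tensor, num_samples: int = 1):
    """Hypothetical helper: inverse-CDF sampling that returns (values, indices)."""
    # The lookup itself needs no gradient, so build the CDF on a detached copy.
    cdf = probs.detach().cumsum(0)
    # Uniform draws in [0, 1), scaled by the total in case the sum is not exactly 1.
    # Same dtype/device as the CDF, since searchsorted expects them to match.
    u = torch.rand(num_samples, dtype=cdf.dtype, device=cdf.device) * cdf[-1]
    # First index whose cumulative sum is >= u; clamp guards against rounding
    # pushing a draw past the last boundary.
    idx = torch.searchsorted(cdf, u).clamp(max=probs.numel() - 1)
    # Indexing the original tensor keeps the values attached to the graph.
    return probs[idx], idx

A = torch.tensor([0.0316, 0.2338, 0.2338, 0.2338, 0.0316, 0.0316, 0.0860, 0.0316, 0.0862], requires_grad=True)
values, indices = sample_with_index(A, num_samples=5)
print(values, indices)
```

The returned `values` still carry a `grad_fn`, and `indices` tells you which position each value came from, so duplicates such as `0.2338` remain distinguishable.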

In the timings below, this is faster than the more common approach with `A.multinomial(1)`.
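For comparison, here is a minimal sketch of the `multinomial` approach. It also returns an index, so the value can be recovered by indexing `A`. The `detach()` call is my addition; sampling is not differentiable, so it should not change the result.

```python
import torch

A = torch.tensor([0.0316, 0.2338, 0.2338, 0.2338, 0.0316, 0.0316, 0.0860, 0.0316, 0.0862], requires_grad=True)

# multinomial returns the sampled index as an integer (long) tensor.
idx = torch.multinomial(A.detach(), 1)
# Indexing the original tensor recovers the value and keeps it in the graph.
value = A[idx]
print(value, idx)
```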

Sampling one element 10000 times with each approach, to check that the distribution conforms to the probabilities and to compare timings:

```
from collections import Counter

# multinomial
Counter(int(A.multinomial(1)) for _ in range(10000))
# 1 loop, best of 5: 233 ms per loop

# vs @HatemAli's solution
dist = torch.distributions.categorical.Categorical(probs=A)
Counter(int(dist.sample()) for _ in range(10000))
# 10 loops, best of 5: 107 ms per loop

# cumsum + searchsorted (p = A.cumsum(0) from above)
Counter(int(torch.searchsorted(p, torch.rand(1))) for _ in range(10000))
# 10 loops, best of 5: 53.2 ms per loop
```

Output

```
Counter({0: 319,
         1: 2360,
         2: 2321,
         3: 2319,
         4: 330,
         5: 299,
         6: 903,
         7: 298,
         8: 851})
```
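The check can also be done without a Python loop by drawing all samples in one call and tallying them with `torch.bincount`. This is a sketch rather than a benchmark; the sample size, the seed and the variable names are arbitrary choices of mine.

```python
import torch

torch.manual_seed(0)  # fixed seed only so the run is repeatable

A = torch.tensor([0.0316, 0.2338, 0.2338, 0.2338, 0.0316, 0.0316, 0.0860, 0.0316, 0.0862])
p = A.cumsum(0)

n = 100_000
# One uniform draw per sample, looked up in the cumulative sums in a single call.
u = torch.rand(n)
# Clamp is a guard in case rounding leaves p[-1] slightly below a draw.
idx = torch.searchsorted(p, u).clamp(max=A.numel() - 1)

# Count how often each index was drawn and convert to relative frequencies.
counts = torch.bincount(idx, minlength=A.numel())
freq = counts.float() / n

# Largest deviation between empirical frequency and target probability.
max_abs_error = (freq - A).abs().max().item()
print(freq, max_abs_error)
```

With this many samples the empirical frequencies should land close to the entries of `A`, with indices `1`, `2` and `3` each near `0.2338`.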
