
# PyTorch: How to sample from a tensor where each value in the tensor has a different likelihood of being selected?

Given tensor `A = torch.tensor([0.0316, 0.2338, 0.2338, 0.2338, 0.0316, 0.0316, 0.0860, 0.0316, 0.0860])` containing probabilities which sum to 1 (I removed some decimals but it's safe to assume it'll always sum to 1), I want to sample a value from `A` where the value itself is the likelihood of getting sampled. For instance, the likelihood of sampling `0.0316` from `A` is `0.0316`. The output of the value sampled should still be a tensor.

I tried using `WeightedRandomSampler`, but the value it selects is no longer a tensor; it gets detached instead.

One caveat that makes this tricky is that I also want to know the index of the sampled value in the tensor. That is, if I sample `0.2338`, I want to know whether it is index `1`, `2` or `3` of `A`.
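For reference, `WeightedRandomSampler` yields integer indices rather than values, so one way to keep the result attached to the autograd graph, and to get the index at the same time, is to index `A` with what the sampler returns. A minimal sketch, assuming a single draw (`num_samples=1`) is all that is needed:

```python
import torch
from torch.utils.data import WeightedRandomSampler

A = torch.tensor([0.0316, 0.2338, 0.2338, 0.2338, 0.0316,
                  0.0316, 0.0860, 0.0316, 0.0862], requires_grad=True)

# The sampler only needs the weights, so pass a detached copy of A.
sampler = WeightedRandomSampler(A.detach(), num_samples=1, replacement=True)

# Iterating the sampler yields plain Python ints: indices into A.
idx = torch.tensor(list(sampler))

# Indexing A with those indices returns a tensor that still tracks gradients.
value = A[idx]
print(value, idx)
```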

Solution

• Sampling with the desired probabilities can be done by taking the cumulative sum of the weights and finding, with `torch.searchsorted`, the insertion index of a random float in [0, 1). The example tensor `A` is slightly adjusted (last entry `0.0862` instead of `0.0860`) so that it sums to 1.

```python
import torch

A = torch.tensor([0.0316, 0.2338, 0.2338, 0.2338, 0.0316, 0.0316, 0.0860, 0.0316, 0.0862], requires_grad=True)

# Cumulative distribution: p[i] is the total probability of indices 0..i
p = A.cumsum(0)
# tensor([0.0316, 0.2654, 0.4992, 0.7330, 0.7646, 0.7962, 0.8822, 0.9138, 1.0000], grad_fn=<CumsumBackward0>)

# Index of the first cumulative value >= a uniform draw in [0, 1)
idx = torch.searchsorted(p, torch.rand(1))
A[idx], idx
```

Output

```
(tensor([0.2338], grad_fn=<IndexBackward0>), tensor([3]))
```
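The same cumulative-sum idea extends to several draws at once. A minimal sketch, using a hypothetical helper `sample_with_index` (not a PyTorch API). It scales the uniform draws by the last cumulative value instead of assuming it is exactly 1, because floating-point rounding can leave `p[-1]` slightly below 1, in which case a draw close to 1 would make `searchsorted` return `len(A)`, an out-of-range index:

```python
import torch

def sample_with_index(probs: torch.Tensor, num_samples: int = 1):
    """Hypothetical helper: draw `num_samples` indices with probability
    proportional to `probs` (inverse-CDF sampling), return (values, indices).

    `values` is gathered from `probs` by indexing, so it stays attached
    to the autograd graph.
    """
    # Cumulative sum of the weights; detached because it is only used
    # to locate indices, not to compute gradients.
    cdf = probs.detach().cumsum(0)
    # Uniform draws in [0, total mass], so the index can never exceed len(probs) - 1.
    u = torch.rand(num_samples, dtype=cdf.dtype, device=cdf.device) * cdf[-1]
    idx = torch.searchsorted(cdf, u)
    return probs[idx], idx


A = torch.tensor([0.0316, 0.2338, 0.2338, 0.2338, 0.0316,
                  0.0316, 0.0860, 0.0316, 0.0862], requires_grad=True)

values, idx = sample_with_index(A, num_samples=5)
print(values, idx)

# The gradient of values.sum() w.r.t. A counts how often each position was sampled.
values.sum().backward()
print(A.grad)
```

Because the scaled draws never exceed the total mass, the same helper also works for unnormalized weights.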

In the timings below, this is faster than the more common approach with `A.multinomial(1)`.
Sampling one element 10,000 times also checks that the empirical distribution matches the probabilities:

```python
from collections import Counter

Counter(int(A.multinomial(1)) for _ in range(10000))
# 1 loop, best of 5: 233 ms per loop

# vs @HatemAli's solution
dist = torch.distributions.categorical.Categorical(probs=A)
Counter(int(dist.sample()) for _ in range(10000))
# 10 loops, best of 5: 107 ms per loop

Counter(int(torch.searchsorted(p, torch.rand(1))) for _ in range(10000))
# 10 loops, best of 5: 53.2 ms per loop
```

Output

```
Counter({0: 319,
         1: 2360,
         2: 2321,
         3: 2319,
         4: 330,
         5: 299,
         6: 903,
         7: 298,
         8: 851})
```