Tags: python, random, pytorch, probability, sampling

PyTorch: How to sample from a tensor where each value in the tensor has a different likelihood of being selected?


Given tensor A = torch.tensor([0.0316, 0.2338, 0.2338, 0.2338, 0.0316, 0.0316, 0.0860, 0.0316, 0.0860]) containing probabilities which sum to 1 (I removed some decimals but it's safe to assume it'll always sum to 1), I want to sample a value from A where the value itself is the likelihood of getting sampled. For instance, the likelihood of sampling 0.0316 from A is 0.0316. The output of the value sampled should still be a tensor.

I tried using WeightedRandomSampler, but the selected value is no longer a tensor; it gets detached instead.

One caveat that makes this tricky is that I want to also know the index of the sampled value as it appears in the tensor. That is, say I sample 0.2338, I want to know if it's index 1, 2 or 3 of tensor A.


Solution

  • Selecting with the expected probabilities can be achieved by accumulating the weights (a cumulative sum) and finding the insertion index of a random float drawn uniformly from [0, 1). The last entry of the example array A is slightly adjusted (0.0860 to 0.0862) so that the values sum to 1.

    import torch
    
    A = torch.tensor([0.0316, 0.2338, 0.2338, 0.2338, 0.0316, 0.0316, 0.0860, 0.0316, 0.0862], requires_grad=True)
    
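    # Cumulative sum of the weights: p[i] is the probability of drawing an index <= i
    p = A.cumsum(0)
    # tensor([0.0316, 0.2654, 0.4992, 0.7330, 0.7646, 0.7962, 0.8822, 0.9138, 1.0000], grad_fn=<CumsumBackward0>)
    
    # The insertion index of a uniform random number in [0, 1) is the sampled index
    idx = torch.searchsorted(p, torch.rand(1))
    A[idx], idx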
    

    Output

    (tensor([0.2338], grad_fn=<IndexBackward0>), tensor([3]))
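
The cumulative-sum approach above can be wrapped in a small helper. This is only a sketch, not part of the original answer: the name `sample_with_index` is made up for illustration, and the scaling by the total weight and the final `clamp` are extra safeguards I am assuming you would want in case rounding leaves the cumulative sum slightly off 1.

```python
import torch

def sample_with_index(probs):
    """Draw one entry of `probs` with probability equal to its value.

    Returns (value, index). The value is obtained by indexing `probs`,
    so it stays attached to the autograd graph.
    """
    # Cumulative weights; detached because the index is not differentiable.
    cdf = probs.detach().cumsum(0)
    # Uniform draw in [0, 1), scaled by the total in case it is not exactly 1.
    u = torch.rand(1, dtype=cdf.dtype, device=cdf.device) * cdf[-1]
    # right=True returns the first position whose cumulative weight is
    # strictly greater than u.
    idx = torch.searchsorted(cdf, u, right=True)
    # Guard against an out-of-range index caused by rounding.
    idx = idx.clamp(max=probs.numel() - 1)
    return probs[idx], idx

A = torch.tensor([0.0316, 0.2338, 0.2338, 0.2338, 0.0316,
                  0.0316, 0.0860, 0.0316, 0.0862], requires_grad=True)
value, idx = sample_with_index(A)
```

Here `value` is a one-element tensor that still requires grad, and `idx` tells you which of the duplicate entries (for example index 1, 2 or 3 for 0.2338) was picked.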
    

    This is faster than the more common approach with A.multinomial(1), as the timings in the comments below show.
    Sampling one element 10000 times with each method also checks that the distribution conforms to the probabilities.

    from collections import Counter
    
    Counter(int(A.multinomial(1)) for _ in range(10000))
    # 1 loop, best of 5: 233 ms per loop
    
    # vs @HatemAli's solution
    dist = torch.distributions.categorical.Categorical(probs=A)
    Counter(int(dist.sample()) for _ in range(10000))
    # 10 loops, best of 5: 107 ms per loop
    
    Counter(int(torch.searchsorted(p, torch.rand(1))) for _ in range(10000))
    # 10 loops, best of 5: 53.2 ms per loop
    

    Output

    Counter({0: 319,
             1: 2360,
             2: 2321,
             3: 2319,
             4: 330,
             5: 299,
             6: 903,
             7: 298,
             8: 851})
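
If many samples are needed at once, the same check can probably be done without a Python loop by drawing all random numbers in one call. The following is a sketch under that assumption; the variable names (`n_samples`, `freq`, and so on) are mine, and no timings were measured for it here. `torch.multinomial` with `replacement=True` is included as the built-in counterpart.

```python
import torch

A = torch.tensor([0.0316, 0.2338, 0.2338, 0.2338, 0.0316,
                  0.0316, 0.0860, 0.0316, 0.0862])
n_samples = 10_000  # same sample count as the check above

# One batched searchsorted call instead of 10000 single draws.
cdf = A.cumsum(0)
idx = torch.searchsorted(cdf, torch.rand(n_samples), right=True)
idx = idx.clamp(max=A.numel() - 1)  # guard against rounding at the top end

# Built-in counterpart: sample indices with replacement from the same weights.
idx_multinomial = torch.multinomial(A, n_samples, replacement=True)

# Empirical frequency of each index; both should be close to A.
freq = torch.bincount(idx, minlength=A.numel()) / n_samples
freq_multinomial = torch.bincount(idx_multinomial, minlength=A.numel()) / n_samples
```

Comparing `freq` and `freq_multinomial` against `A` plays the same role as the `Counter` output above: each index should show up roughly in proportion to its probability.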