python, pytorch

PyTorch: randomly choose an index that satisfies a condition


I have a tensor that stores whether or not each index is available:

available = torch.Tensor([1,1,0,0,1,0])

I want to return one of the indices 0, 1, or 4, chosen with equal probability, since available[0], available[1], and available[4] are all equal to 1.

Can somebody help me with this? Thanks.


Solution

  • Torch makes this easy. You can use multinomial as per this answer:

    num_samples = 1
    available.multinomial(num_samples, replacement=False)
    

    Here, num_samples indicates how many samples you'd like to draw.

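    Because available already contains only 1s and 0s, it can be passed directly as the weights. multinomial does not require the weights to sum to 1, so each available index is equally likely and an index with weight 0 is never drawn.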

    If you draw more than 3 samples (the number of nonzero entries in available), this raises an error unless you set replacement=True.
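
For reference, here is a minimal self-contained sketch of the approach above, run on the tensor from the question. The second variant (collecting the nonzero indices and picking one with torch.randint) is not part of the original answer; it is an alternative shown only for comparison, and the variable names idx, candidates, pick, and many are illustrative.

```python
import torch

# Availability mask from the question: 1 = available, 0 = not available.
available = torch.tensor([1.0, 1.0, 0.0, 0.0, 1.0, 0.0])

# Draw one index. Entries with weight 0 are never selected, and the
# equal nonzero weights make indices 0, 1, and 4 equally likely.
idx = available.multinomial(num_samples=1, replacement=False).item()

# Alternative (an assumption, not from the answer): list the available
# indices explicitly, then pick one position uniformly at random.
candidates = torch.nonzero(available, as_tuple=True)[0]
pick = candidates[torch.randint(len(candidates), (1,))].item()

# Drawing more samples than there are available entries needs replacement=True.
many = available.multinomial(num_samples=5, replacement=True)

print(idx, pick, many.tolist())
```

Both variants sample uniformly over the available indices. The multinomial call is shorter, while the explicit candidate list may be easier to reuse if you also need to know how many indices are available.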