Say I have a PyTorch tensor, tensor([3,5,7,3,9,3,0]). I'd like to extract the indices where 3 appears, i.e. tensor([0,3,5]). Is there a built-in function for this?
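Yes, torch.where does this. Called with only a condition, it returns a tuple of index tensors, one per dimension, so for a 1-D tensor take the first element:
torch.where(my_tensor == the_number)[0]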
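For completeness, here is a minimal runnable sketch using the tensor from the question. The names my_tensor and the_number follow the answer; indices and alt_indices are illustrative names added here. It also shows torch.nonzero as an alternative that should give the same positions.

```python
import torch

my_tensor = torch.tensor([3, 5, 7, 3, 9, 3, 0])
the_number = 3

# With only a condition, torch.where returns a tuple holding one index
# tensor per dimension. The input is 1-D, so the tuple has one element.
(indices,) = torch.where(my_tensor == the_number)
print(indices)  # tensor([0, 3, 5])

# Alternative: torch.nonzero returns an (N, 1) tensor of positions for a
# 1-D input, so flatten it to get a plain 1-D index tensor.
alt_indices = torch.nonzero(my_tensor == the_number).flatten()
print(alt_indices)  # tensor([0, 3, 5])
```

For a multi-dimensional tensor, the tuple from torch.where has one entry per dimension (row indices, column indices, and so on), so unpack accordingly rather than taking only the first element.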