I have the following torch tensor:
tensor([[-0.2, 0.3],
        [-0.5, 0.1],
        [-0.4, 0.2]])
and the following numpy array: (I can convert it to something else if necessary)
[1 0 1]
I want to get the following tensor:
tensor([0.3, -0.5, 0.2])
i.e. I want the numpy array to index each sub-element of my tensor. Preferably without using a loop.
Thanks in advance
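You may want to use torch.gather, which "Gathers values along an axis specified by dim." For a 2-D tensor with dim=1 this means out[i][j] = t[i][idxs[i][j]], so a column of indices picks one element from each row.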
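import numpy as np
import torch

t = torch.tensor([[-0.2, 0.3],
                  [-0.5, 0.1],
                  [-0.4, 0.2]])
idxs = np.array([1, 0, 1])
# gather needs an int64 index tensor with the same number of dimensions as t
idxs = torch.from_numpy(idxs).long().unsqueeze(1)
# or torch.from_numpy(idxs).long().view(-1,1)
t.gather(1, idxs)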
tensor([[ 0.3000],
        [-0.5000],
        [ 0.2000]])
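Your index is a NumPy array, so it has to be converted to a LongTensor (int64) before gather will accept it. Note that the result has shape (3, 1), not (3,) as in the question.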
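To get exactly the 1-D tensor asked for, one option is to squeeze the gathered result; another is plain advanced indexing, which pairs each row number with its column index. The following is only a minimal sketch of both, and the names rows, picked_gather and picked_index are just illustrative, not anything from the question:

```python
import numpy as np
import torch

t = torch.tensor([[-0.2, 0.3],
                  [-0.5, 0.1],
                  [-0.4, 0.2]])
idxs = torch.from_numpy(np.array([1, 0, 1])).long()

# Option 1: gather along dim=1, then drop the trailing size-1 dimension.
picked_gather = t.gather(1, idxs.unsqueeze(1)).squeeze(1)

# Option 2: advanced indexing, pairing row i with column idxs[i].
rows = torch.arange(t.size(0))
picked_index = t[rows, idxs]

print(picked_gather.shape)  # 1-D, one value per row
print(torch.equal(picked_gather, picked_index))
```

Both variants avoid a Python loop. gather is probably the more natural fit if the index already has the same number of dimensions as the tensor, while the indexing form tends to be shorter for the simple one-index-per-row case.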