I am searching for a way to do some batchwise indexing for tensors.
If I have a variable Q of size 1000, I can get the elements I want with Q[index], where index is a vector of the wanted indices.
Now I would like to do the same for higher-dimensional tensors. So suppose Q is of shape n x m and I have an index matrix of shape n x p.
My goal is to get for each of the n rows the specific p elements out of the m elements.
But Q[index] does not work in this situation.
Do you have any thoughts on how to handle this?
This seems to be a simple application of torch.gather,
which doesn't require any additional reshaping of the data or index tensor:
>>> Q = torch.rand(5, 4)
>>> Q
tensor([[0.8462, 0.3064, 0.2549, 0.2149],
        [0.6801, 0.5483, 0.5522, 0.6852],
        [0.1587, 0.4144, 0.8843, 0.6108],
        [0.5265, 0.8269, 0.8417, 0.6623],
        [0.8549, 0.6437, 0.4282, 0.2792]])
>>> index
tensor([[0, 1, 2],
        [2, 3, 1],
        [0, 1, 2],
        [2, 2, 2],
        [1, 1, 2]])
The following gather operation, applied on dim=1, returns a tensor out such that:
out[i, j] = Q[i, index[i,j]]
This is done with the following call of torch.Tensor.gather on Q:
>>> Q.gather(dim=1, index=index)
tensor([[0.8462, 0.3064, 0.2549],
        [0.5522, 0.6852, 0.5483],
        [0.1587, 0.4144, 0.8843],
        [0.8417, 0.8417, 0.8417],
        [0.6437, 0.6437, 0.4282]])
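To check the rule out[i, j] = Q[i, index[i, j]] on your own data, here is a minimal self-contained sketch. The sizes n, m, p, the seed, and the randomly generated index are only illustrative assumptions, not part of the question. It compares the gather result with an explicit loop and with an equivalent advanced-indexing expression; note that gather expects index to be an integer (int64) tensor with the same number of dimensions as Q.

```python
import torch

torch.manual_seed(0)  # arbitrary seed, only for reproducibility

# Illustrative sizes (assumptions): n rows, m columns, p picks per row
n, m, p = 5, 4, 3
Q = torch.rand(n, m)
index = torch.randint(0, m, (n, p))  # int64 column indices, shape (n, p)

# Batched lookup along dim=1: out[i, j] = Q[i, index[i, j]]
out = Q.gather(dim=1, index=index)

# Reference result built with an explicit (slow) double loop
expected = torch.empty(n, p)
for i in range(n):
    for j in range(p):
        expected[i, j] = Q[i, index[i, j]]

# Same result via advanced indexing: a column of row ids of shape (n, 1)
# broadcasts against index of shape (n, p)
rows = torch.arange(n).unsqueeze(1)
alt = Q[rows, index]

print(out.shape)                  # torch.Size([5, 3])
print(torch.equal(out, expected))
print(torch.equal(out, alt))
```

Both comparisons should agree, since all three expressions only copy existing entries of Q. The advanced-indexing form can be handy when you also need to assign into the selected positions, while gather keeps the intent explicit.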