In PyTorch, I have a tensor A with shape [b, m, n] and another tensor B with shape [b, k]. I want to index A along dimension 1 with B, so the result tensor should have shape [b, k, n].
I searched but had no luck. torch.index_select only accepts a 1-D index tensor, torch.take indexes the input as if it were flattened, and torch.gather requires the input tensor and the index tensor to have the same number of dimensions.
What you are trying to do is get a tensor out of shape (b, k, n) such that:
out[i][j][l] = A[i][B[i][j]][l]
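To make the target concrete, the relation (each out[i][j] is row B[i][j] of A[i]) can be written as a naive reference loop. This is only a sketch for checking results on small tensors; the sizes and the loop variable names are assumptions chosen for illustration.

```python
import torch

# Illustrative sizes (assumptions, not taken from the question)
b, m, n, k = 2, 5, 3, 4

A = torch.rand(b, m, n)
B = torch.randint(0, m, (b, k))

# Reference loop: for batch i and slot j, copy row B[i, j] of A[i]
out = torch.empty(b, k, n)
for i in range(b):
    for j in range(k):
        out[i, j] = A[i, B[i, j]]
```

A loop like this is slow for large tensors, but it is a handy ground truth to compare a vectorized version against.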
To use torch.gather, the index does indeed need the same number of dimensions as the input. You can get there by adding a trailing singleton dimension to B and expanding it, so that the index has shape (b, k, n).
Here is a minimal example:
import torch
b, m, n, k = 2, 5, 3, 4  # example sizes, chosen arbitrarily
A = torch.rand(b, m, n)
B = torch.randint(0, m, (b, k))
Expand B:
>>> B_ = B[:,:,None].expand(-1,-1,A.size(-1))
Gather values from A:
>>> A.gather(1,B_)
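Putting the steps together, here is a self-contained sketch that also cross-checks the gather result against plain advanced indexing. The sizes and the names batch_idx and out_alt are assumptions for illustration; the advanced-indexing line is an alternative route to the same result, not part of the gather approach itself.

```python
import torch

# Illustrative sizes (assumptions, not taken from the question)
b, m, n, k = 2, 5, 3, 4

A = torch.rand(b, m, n)
B = torch.randint(0, m, (b, k))

# Add a trailing singleton dimension, then expand it to size n.
# expand returns a view, so no index data is copied.
B_ = B[:, :, None].expand(-1, -1, A.size(-1))  # shape (b, k, n)

# Gather along dim 1: out[i, j, l] = A[i, B_[i, j, l], l]
out = A.gather(1, B_)                          # shape (b, k, n)

# Cross-check: broadcast a batch index of shape (b, 1) against B of shape (b, k)
batch_idx = torch.arange(b)[:, None]
out_alt = A[batch_idx, B]                      # shape (b, k, n)

assert torch.equal(out, out_alt)
```

If you only need the values, the advanced-indexing form is shorter. The gather form is useful when you want to stay within gather/scatter semantics, for example to pair it with a later scatter along the same dimension.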