python, indexing, pytorch

PyTorch: indexing a multi-dimensional tensor with another multi-dimensional tensor


In PyTorch, I have a tensor A of shape [b, m, n] and an index tensor B of shape [b, k] whose values lie in [0, m). I want to use B to index A along dimension 1, so that the result has shape [b, k, n].

I searched but had no luck. torch.index_select only accepts a 1-D index tensor, and torch.take indexes into the input as if it were flattened to 1-D. torch.gather requires the input and index tensors to have the same number of dimensions.
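To pin down the intended result, here is a slow loop-based sketch of the semantics, useful as a reference when checking a vectorized version. The helper name `index_rows_loop` is hypothetical, and the small tensors are made-up example values.

```python
import torch


def index_rows_loop(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    """Hypothetical reference helper: out[i, j, :] = A[i, B[i, j], :]."""
    b, m, n = A.shape
    k = B.shape[1]
    out = torch.empty(b, k, n, dtype=A.dtype, device=A.device)
    for i in range(b):          # loop over the batch
        for j in range(k):      # loop over the k indices of this batch element
            # Copy the length-n row of A[i] selected by B[i, j]
            out[i, j] = A[i, int(B[i, j])]
    return out


# Small made-up example: b=2, m=3, n=2, k=2
A = torch.arange(12, dtype=torch.float32).reshape(2, 3, 2)
B = torch.tensor([[2, 0], [1, 1]])
out = index_rows_loop(A, B)
print(out.shape)  # torch.Size([2, 2, 2])
```

Here out[0] holds rows 2 and 0 of A[0] ([4, 5] and [0, 1]), and out[1] holds row 1 of A[1] ([8, 9]) twice.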


Solution

  • What you are trying to do is compute out such that:

    out[i][j][l] = A[i][B[i][j]][l]    for 0 <= i < b, 0 <= j < k, 0 <= l < n
    

    To use torch.gather, the index tensor must indeed have the same number of dimensions as the input. You can get there by adding a trailing singleton dimension to B and expanding it to shape (b, k, n), so that every position along n reuses the same index B[i][j].

    Here is a minimal example:

    import torch
    b, m, n, k = 2, 5, 3, 4  # example sizes
    A = torch.rand(b, m, n)
    B = torch.randint(0, m, (b, k))
    

    Expand B to shape (b, k, n):

    >>> B_ = B[:,:,None].expand(-1,-1,A.size(-1))
    

    Gather the values from A along dimension 1; the result has shape (b, k, n):

    >>> A.gather(1,B_)
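Putting the pieces together, here is a self-contained sketch of the answer with arbitrary example sizes. It also cross-checks the result against advanced indexing with a broadcast batch index, A[torch.arange(b)[:, None], B]; that variant is an alternative not taken from the answer above and is shown only for comparison.

```python
import torch

b, m, n, k = 2, 5, 3, 4  # arbitrary example sizes
A = torch.rand(b, m, n)
B = torch.randint(0, m, (b, k))

# Add a trailing singleton dim (b, k, 1) and broadcast it to (b, k, n).
# expand() returns a view: the new dim has stride 0, so no data is copied.
B_ = B[:, :, None].expand(-1, -1, A.size(-1))

# gather along dim 1: out[i, j, l] = A[i, B_[i, j, l], l] = A[i, B[i, j], l]
out = A.gather(1, B_)

# Alternative (advanced indexing): a (b, 1) batch index broadcasts against
# B's (b, k), selecting A[i, B[i, j], :] and keeping the trailing n dim.
batch_idx = torch.arange(b)[:, None]
out_adv = A[batch_idx, B]

print(out.shape)                  # torch.Size([2, 4, 3])
print(B_.stride(-1))              # 0
print(torch.equal(out, out_adv))  # True
```

Both produce the same (b, k, n) tensor. gather is the approach the answer describes; advanced indexing avoids building the expanded index explicitly.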