
Indexing and slicing in pytorch


I have a batch of 2D tensors of shape [B,n,n], and a tensor of coordinates of shape [B,2].

a = torch.arange(48).reshape((3,4,4))
coords = torch.tensor([[0,1],[1,2],[1,3]],dtype=int)

The code that gives me the result I want is:

a[torch.arange(3),coords[:,0],coords[:,1]]

I just don't understand why I cannot use the following instead, since I thought ':' meant taking all indices:

a[:,coords[:,0],coords[:,1]]

What am I missing here?


Solution

  • You can think of the difference as parallel indexing vs selecting by combinations.

    • In the former, you index dim=0 with a tensor (torch.arange(3)) of the same length as the dim=1 and dim=2 index tensors, which specifies what to select for each batch element. Written out, the indexing performed by a[torch.arange(3),coords[:,0],coords[:,1]] is:

      >>> a[[0, 1, 2], [0, 1, 1], [1, 2, 3]]
      tensor([ 1, 22, 39])
      

      The three index tensors are therefore read in "parallel", which means taking first a[0,0,1], then a[1,1,2], and finally a[2,1,3]. This corresponds to individual values being selected and stacked:

      >>> torch.stack([a[0,0,1], a[1,1,2], a[2,1,3]])
      tensor([ 1, 22, 39])
      
    • In the latter, you perform combinations since : refers to a "select all" on dim=0. Here again, if we detail the indexing, we have:

      • a[:,0,1], i.e. a[[0,1,2],0,1], yielding tensor([ 1, 17, 33]).
      • a[:,1,2], i.e. a[[0,1,2],1,2], yielding tensor([ 6, 22, 38]).
      • a[:,1,3], i.e. a[[0,1,2],1,3], yielding tensor([ 7, 23, 39]).

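      This is performed column-wise and not batch-wise: for all batch elements, we first take all (0,1)s, then all (1,2)s, and finally all (1,3)s, and each coordinate pair yields one column of the result. A corresponding operation is given here; note that a[:,coords[:,0],coords[:,1]] itself has shape (3, 3), while torch.dstack adds a leading dimension of size 1:

      >>> torch.dstack([a[:,0,1], a[:,1,2], a[:,1,3]])
      tensor([[[ 1,  6,  7],
               [17, 22, 23],
               [33, 38, 39]]])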
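
To see both behaviours side by side, here is a minimal sketch using the same a and coords as in the question. The names rows, cols, parallel, combos and diagonal are only illustrative and do not come from the original post. It also shows how the two results relate: since combos[b, k] equals a[b, rows[k], cols[k]], the "parallel" result should be the diagonal of the "combinations" result.

```python
import torch

a = torch.arange(48).reshape(3, 4, 4)
coords = torch.tensor([[0, 1], [1, 2], [1, 3]], dtype=torch.long)
rows, cols = coords[:, 0], coords[:, 1]

# Parallel indexing: three index tensors of length 3 are paired
# element by element, giving one value per batch element.
parallel = a[torch.arange(3), rows, cols]   # shape (3,)

# Combinations: the slice keeps every batch element, and every
# (row, col) pair is applied to each of them.
combos = a[:, rows, cols]                   # shape (3, 3)

# combos[b, k] == a[b, rows[k], cols[k]], so picking k == b
# (the diagonal) recovers the parallel result.
diagonal = combos.diagonal()

print(parallel.shape, combos.shape)
print(torch.equal(diagonal, parallel))
```

If only the per-batch values are needed, the first form avoids building the full (3, 3) result and then discarding everything off the diagonal.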