I have a batch of 2d tensors of dimension [B,n,n], and a tensor of coordinates of dimension [B,2].
a = torch.arange(48).reshape((3,4,4))
coords = torch.tensor([[0,1],[1,2],[1,3]],dtype=int)
The code that gives me the result I want is:
a[torch.arange(3),coords[:,0],coords[:,1]]
I just don't understand why I cannot use the following instead, as I thought ':' meant taking all indices:
a[:,coords[:,0],coords[:,1]]
What am I missing here?
You can think of the difference as parallel indexing vs selecting by combinations.
In the former, you use an arange of the same length as your dim=1 and dim=2 tensor indexers, so the three index tensors together signify what to select from each batch element. If you detail the indexing happening with a[torch.arange(3),coords[:,0],coords[:,1]], you are doing:
>>> a[[0, 1, 2], [0, 1, 1], [1, 2, 3]]
tensor([ 1, 22, 39])
Therefore, you index what I refer to as "in parallel": you take first a[0,0,1], then a[1,1,2], and finally a[2,1,3]. This corresponds to the individual values being selected and stacked:
>>> torch.stack([a[0,0,1], a[1,1,2], a[2,1,3]])
tensor([ 1, 22, 39])
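
As a side note, if you prefer not to build the torch.arange(3) indexer, one equivalent sketch (using the same a and coords as above) flattens each matrix and uses torch.gather to pick one flat index (row * n + col) per batch element:

>>> import torch
>>> a = torch.arange(48).reshape(3, 4, 4)
>>> coords = torch.tensor([[0, 1], [1, 2], [1, 3]])
>>> B, n, _ = a.shape
>>> # one flattened index per batch element: row * n + col
>>> flat_idx = coords[:, 0] * n + coords[:, 1]
>>> a.reshape(B, -1).gather(1, flat_idx.unsqueeze(1)).squeeze(1)
tensor([ 1, 22, 39])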
In the latter, you perform combinations, since : refers to a "select all" on dim=0. Here again, if we detail the indexing, we have:

- a[:,0,1], i.e. a[[0,1,2],0,1], yielding tensor([ 1, 17, 33])
- a[:,1,2], i.e. a[[0,1,2],1,2], yielding tensor([ 6, 22, 38])
- a[:,1,3], i.e. a[[0,1,2],1,3], yielding tensor([ 7, 23, 39])

This is performed column-wise and not batch-wise: for all batch elements, we first take all (0,1)s, then all (1,2)s, and finally all (1,3)s. The corresponding operation can be performed with:
>>> torch.dstack([a[:,0,1], a[:,1,2], a[:,1,3]])
tensor([[[ 1, 6, 7],
[17, 22, 23],
[33, 38, 39]]])
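
Put differently, a[:, coords[:,0], coords[:,1]] gives you the full grid of every (batch element, coordinate pair) combination, and the "parallel" values you wanted sit on its diagonal. A small sketch to check this with the same a and coords:

>>> grid = a[:, coords[:, 0], coords[:, 1]]  # shape [3, 3]: rows are batch elements, columns are coordinate pairs
>>> grid
tensor([[ 1,  6,  7],
        [17, 22, 23],
        [33, 38, 39]])
>>> torch.diagonal(grid)  # element i picks coordinate pair i from batch element i
tensor([ 1, 22, 39])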