Let `a` and `b` be two PyTorch tensors with `a.shape = [A,3]` and `b.shape = [B,3]`. Furthermore, `b` is of type `long`.
I know there are several ways of slicing `a`. For example,
```python
c = a[N1:N2:jump, [0, 2]]  # N1 < N2 < A
```
returns `c.shape = [2,2]` for `N1=1`, `N2=4` and `jump=2` (rows 1 and 3, columns 0 and 2).
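A minimal check of that slicing example, assuming `A = 10` (an arbitrary size chosen here for illustration). It mixes a basic slice on dimension 0 with a list index on dimension 1, and compares the result with an element-by-element construction:

```python
import torch

a = torch.rand(10, 3)  # assumed A = 10 rows, 3 columns

# Rows 1 and 3 (start 1, stop 4, step 2), columns 0 and 2
c = a[1:4:2, [0, 2]]
print(c.shape)  # torch.Size([2, 2])

# The same selection built row by row
expected = torch.stack([a[1, [0, 2]], a[3, [0, 2]]])
print(torch.equal(c, expected))  # True
```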
But I expected the following to throw an error:
```python
c = a[b]
```
Instead, it runs and `c.shape = [B,3,3]`.
For example:
```python
import torch

a = torch.rand(10, 3)
b = torch.rand(20, 3).long()
print(a[b].shape)  # torch.Size([20, 3, 3])
```
Can someone explain how the indexing works for `a[b]`?
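It helps to look at what `b` actually contains in this example. `torch.rand` draws from [0, 1), and converting those floats with `.long()` truncates toward zero, so every entry of `b` is 0. A short sketch confirming that, reusing the sizes from the question:

```python
import torch

a = torch.rand(10, 3)
b = torch.rand(20, 3).long()  # values in [0, 1) truncate to 0

print(b.unique())  # tensor([0])

# Every entry of b selects row 0 of a, so each length-3 vector in a[b] is a copy of a[0]
out = a[b]
print(out.shape)  # torch.Size([20, 3, 3])
print(torch.equal(out, a[0].expand(20, 3, 3)))  # True
```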
Basics
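- When you write `a[b]` with an integer (`long`) tensor `b`, PyTorch performs advanced (integer-array) indexing on the first dimension of `a`.
- Each *element* of `b`, not each row, is treated as an index into the first dimension of `a`, and the whole row of `a` at that index is returned.
- The result therefore has shape `b.shape + a.shape[1:]`. With `b.shape = [B,3]` and `a.shape = [A,d]`, `a[b]` has shape `[B,3,d]`. In the question `d = 3`, so the result is `[B,3,3]`; the two 3s come from different places (the columns of `b` and the columns of `a`).
- No error is raised because every value in `b` is a valid row index. Indexing only fails when some value lies outside the range `[-A, A)`.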
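A sketch of the general shape rule, using hypothetical sizes `A = 10`, `d = 4`, `B = 5` chosen so that the trailing dimensions of `a` and `b` differ and can be told apart. It also compares `a[b]` with an equivalent `index_select` formulation and shows that an out-of-range index is what actually raises:

```python
import torch

A, d = 10, 4  # hypothetical sizes; d != 3 so the dimensions are distinguishable
B, K = 5, 3
a = torch.rand(A, d)
b = torch.randint(0, A, (B, K))  # random valid row indices in [0, A), dtype int64

out = a[b]
print(out.shape)  # torch.Size([5, 3, 4]), i.e. b.shape + a.shape[1:]

# Each index b[i, j] picks the whole row a[b[i, j]]
i, j = 2, 1
print(torch.equal(out[i, j], a[b[i, j].item()]))  # True

# Equivalent: select rows with a flat index, then restore b's shape
alt = a.index_select(0, b.flatten()).view(*b.shape, d)
print(torch.equal(out, alt))  # True

# A row index equal to A is out of bounds and raises
raised = False
try:
    a[torch.tensor([[A]])]
except (IndexError, RuntimeError) as err:
    raised = True
    print(type(err).__name__, err)
```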
For example
Suppose that `b` has the following values:
```python
b = torch.tensor([[0, 1, 2], [3, 4, 5], [1, 2, 3]])
```
- Then `a[b]` is a tensor with shape `[3,3,3]`: the first dimension runs over the three rows of `b`, the second over the three indices within each row of `b`, and the third over the three columns of `a`.
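A minimal sketch of this example, assuming `a = torch.rand(10, 3)` as in the question (any `a` with at least 6 rows works, since the largest index in `b` is 5):

```python
import torch

a = torch.rand(10, 3)  # assumed, as in the question
b = torch.tensor([[0, 1, 2], [3, 4, 5], [1, 2, 3]])

out = a[b]
print(out.shape)  # torch.Size([3, 3, 3])
```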
Here's how the values are computed:
- The first row of `b` is `[0,1,2]`, so rows 0, 1 and 2 of `a` (the first, second and third rows) are returned, in that order.
- The first "slice" of the result, `a[b][0]`, is therefore:
[[a[0,0], a[0,1], a[0,2]],
[a[1,0], a[1,1], a[1,2]],
[a[2,0], a[2,1], a[2,2]]]
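A quick check of the first slice, again assuming `a = torch.rand(10, 3)`:

```python
import torch

a = torch.rand(10, 3)
b = torch.tensor([[0, 1, 2], [3, 4, 5], [1, 2, 3]])
out = a[b]

# First slice: rows 0, 1, 2 of a, in that order
print(torch.equal(out[0], a[[0, 1, 2]]))  # True
print(torch.equal(out[0], a[0:3]))        # True, because these indices happen to be contiguous
```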
- The second row of `b` is `[3,4,5]`, so rows 3, 4 and 5 of `a` (the fourth, fifth and sixth rows) are returned, in that order.
- The second "slice" of the result, `a[b][1]`, is therefore:
[[a[3,0], a[3,1], a[3,2]],
[a[4,0], a[4,1], a[4,2]],
[a[5,0], a[5,1], a[5,2]]]
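The second slice can also be built by hand, one row of `a` per index in `b[1]`. A sketch assuming `a = torch.rand(10, 3)`:

```python
import torch

a = torch.rand(10, 3)
b = torch.tensor([[0, 1, 2], [3, 4, 5], [1, 2, 3]])
out = a[b]

# Gather a[3], a[4], a[5] one at a time and stack them into a [3, 3] slice
slice1 = torch.stack([a[k] for k in b[1].tolist()])
print(torch.equal(out[1], slice1))  # True
print(torch.equal(out[1], a[3:6]))  # True
```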
- The third row of `b` is `[1,2,3]`, so rows 1, 2 and 3 of `a` (the second, third and fourth rows) are returned, in that order.
- The third "slice" of the result, `a[b][2]`, is therefore:
[[a[1,0], a[1,1], a[1,2]],
[a[2,0], a[2,1], a[2,2]],
[a[3,0], a[3,1], a[3,2]]]
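The third slice reuses rows that already appeared in the first one, which is allowed: indices in `b` may repeat. Advanced indexing copies the selected rows, so the repeated rows are independent copies rather than views into `a`. A sketch assuming `a = torch.rand(10, 3)`:

```python
import torch

a = torch.rand(10, 3)
b = torch.tensor([[0, 1, 2], [3, 4, 5], [1, 2, 3]])
out = a[b]

# Row 1 of a appears twice: out[0, 1] (b[0, 1] == 1) and out[2, 0] (b[2, 0] == 1)
print(torch.equal(out[0, 1], a[1]), torch.equal(out[2, 0], a[1]))  # True True

# Writing into out leaves a (and the other copy of row 1) untouched
a_before = a.clone()
out[2, 0] = -1.0
print(torch.equal(a, a_before))      # True
print(torch.equal(out[0, 1], a[1]))  # True
```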
These three `[3,3]` slices are stacked along a new first dimension to produce the final result with shape `[3,3,3]`.
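To see why "stacked" rather than "concatenated" is the right description, compare `a[b]` with `torch.stack` and `torch.cat` over the per-row slices. A sketch assuming `a = torch.rand(10, 3)`:

```python
import torch

a = torch.rand(10, 3)
b = torch.tensor([[0, 1, 2], [3, 4, 5], [1, 2, 3]])
out = a[b]

slices = [a[row] for row in b]  # one [3, 3] slice per row of b

print(torch.equal(out, torch.stack(slices)))  # True: stacked along a new dim 0
print(torch.cat(slices).shape)  # torch.Size([9, 3]): concatenation would merge the rows instead
```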