Understanding PyTorch Tensor Slicing

Let a and b be two PyTorch tensors with a.shape=[A,3] and b.shape=[B,3]. Further, b is of type long.

I know there are several ways to slice a. For example,

c = a[N1:N2:jump,[0,2]] # N1<N2<A

would return c.shape = [2,2] for N1=1 and N2=4 and jump=2.
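
As a quick sanity check of that shape, here is a minimal sketch. The concrete size A=10 and the random contents of a are assumptions chosen only for illustration.

```python
import torch

a = torch.rand(10, 3)          # assumed A = 10 rows, 3 columns
N1, N2, jump = 1, 4, 2

# Rows 1 and 3 (start at 1, stop before 4, step 2), columns 0 and 2.
c = a[N1:N2:jump, [0, 2]]
print(c.shape)                 # torch.Size([2, 2])
```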

But I expected the following to throw an error,

c = a[b]

but instead c.shape = [B,3,3].

For example,

a = torch.rand(10,3)
b = torch.rand(20,3).long()
print(a[b].shape) #torch.Size([20, 3, 3])

Can someone explain how the slicing is working for a[b]?


  • Basics

    • When you use a[b] with an integer (long) tensor b, PyTorch performs advanced indexing rather than slicing.
    • In this case, every element of b is treated as an index into the first dimension of a, and the corresponding row of a is returned in its place.
    • Since b has shape [B,3], each row of b holds 3 separate row indices into a, and each index is replaced by a row of length d. So the result of a[b] has shape [B,3,d], where d is the number of columns in a (here d=3). All values in b must be valid row indices of a, otherwise an index error is raised.
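
The shape rule above can be checked with a short sketch. The tensor sizes mirror the question, and the index values drawn with torch.randint are an assumption made only so that every index is a valid row of a.

```python
import torch

a = torch.rand(10, 3)                  # 10 rows, d = 3 columns
b = torch.randint(0, 10, (20, 3))      # long indices in [0, 10), shape [20, 3]

c = a[b]

# Result shape is b.shape followed by the trailing dimensions of a.
assert tuple(c.shape) == tuple(b.shape) + tuple(a.shape[1:])   # (20, 3, 3)

# Each element b[i, j] selects one whole row of a.
for i in range(b.shape[0]):
    for j in range(b.shape[1]):
        assert torch.equal(c[i, j], a[b[i, j]])

# Same result: gather rows with a flat index, then restore b's shape.
flat = a.index_select(0, b.reshape(-1)).reshape(*b.shape, a.shape[1])
assert torch.equal(c, flat)
```

The index_select form is only another way to see the same lookup: the indices are flattened, rows are gathered, and the leading shape of b is restored.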

    For example, suppose that b has the following values:

    b = torch.tensor([[0,1,2], [3,4,5], [1,2,3]])
    • Then the result of a[b] will be a tensor with shape [3,3,3], where the first dimension corresponds to the three rows of b and the second dimension corresponds to the three indices in each row of b. The third dimension corresponds to the three columns of a.

    Here's how the values are computed:

    • The first row of b is [0,1,2].
    • This means that the first row of a is returned,
    • followed by the second row of a, and then the third row of a.
    • So the first "slice" of the result will be:
    [[a[0,0], a[0,1], a[0,2]],
     [a[1,0], a[1,1], a[1,2]],
     [a[2,0], a[2,1], a[2,2]]]

    The second row of b is [3,4,5].

    • This means that the fourth row of a is returned,
    • followed by the fifth row of a,
    • and then the sixth row of a.
    • So the second "slice" of the result will be:
    [[a[3,0], a[3,1], a[3,2]],
     [a[4,0], a[4,1], a[4,2]],
     [a[5,0], a[5,1], a[5,2]]]

    The third row of b is [1,2,3].

    • This means that the second row of a is returned,
    • followed by the third row of a,
    • and then the fourth row of a.
    • So the third "slice" of the result will be:
    [[a[1,0], a[1,1], a[1,2]],
     [a[2,0], a[2,1], a[2,2]],
     [a[3,0], a[3,1], a[3,2]]]

    All of these slices are stacked along a new first dimension (not concatenated, which would give shape [9,3]) to produce the final result with shape [3,3,3].
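
The walkthrough above can be reproduced as a minimal sketch. The contents of a are an assumption (a has 10 rows filled with torch.arange so that each row is easy to recognise); b is the example tensor from above.

```python
import torch

# Assumed contents: row r of a is [3r, 3r+1, 3r+2].
a = torch.arange(30.0).reshape(10, 3)
b = torch.tensor([[0, 1, 2], [3, 4, 5], [1, 2, 3]])

c = a[b]
print(c.shape)                          # torch.Size([3, 3, 3])

# Each "slice" c[i] holds the rows of a named by the i-th row of b.
assert torch.equal(c[0], a[[0, 1, 2]])
assert torch.equal(c[1], a[[3, 4, 5]])
assert torch.equal(c[2], a[[1, 2, 3]])

# Building the result by hand: one [3, 3] block per row of b,
# stacked along a new leading dimension.
stacked = torch.stack([a[row] for row in b], dim=0)
assert torch.equal(c, stacked)
```

Using torch.stack here, rather than torch.cat, is what adds the new leading dimension of size 3.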