
How to simplify 3D tensor slicing


I want to slice a 3D tensor in PyTorch. The shape of the 3D tensor src_tensor is (batch, max_len, hidden_dim), and I have a 1D index vector indices with the shape (batch,). I want to slice along the second dimension of src_tensor, taking rows indices[i] and indices[i]+1 for each batch element. I can achieve this with the following code:

import torch
nums = 30
src_tensor = torch.Tensor(range(nums)).reshape((3, 5, 2))  # (batch, max_len, hidden_dim)
indices = [1, 2, 3]  # one start index per batch element
slice_tensor = torch.zeros((3, 2, 2))
for i in range(3):
    p1,p2 = indices[i],indices[i]+1
    slice_tensor[i,:,:]=src_tensor[i,[p1,p2],:]
print(src_tensor)
print(indices)
print(slice_tensor)
"""
tensor([[[ 0.,  1.],
         [ 2.,  3.],
         [ 4.,  5.],
         [ 6.,  7.],
         [ 8.,  9.]],

        [[10., 11.],
         [12., 13.],
         [14., 15.],
         [16., 17.],
         [18., 19.]],

        [[20., 21.],
         [22., 23.],
         [24., 25.],
         [26., 27.],
         [28., 29.]]])
[1, 2, 3]
tensor([[[ 2.,  3.],
         [ 4.,  5.]],

        [[14., 15.],
         [16., 17.]],

        [[26., 27.],
         [28., 29.]]])
"""

My question is whether the above code can be simplified, for example, by eliminating the for loop.


Solution

  • Since your indexing along the second dimension depends on the batch index, you need a slightly more complex indexing scheme:

    import torch
    # arbitrary source tensor
    src_tensor = torch.Tensor(range(30)).view((3, 5, 2))
    B, N, D = src_tensor.shape
    # arbitrary, possibly non-consecutive start indices, one per batch element
    ids = [0, 1, 3]
    # slicing
    sliced = torch.empty(B, 2, D)  # empty instead of zeros to avoid writing values that are immediately overwritten
    sliced[:, 0, :] = src_tensor[range(B), ids]
    sliced[:, 1, :] = src_tensor[:, 1:, :][range(B), ids]
    

    The slicing src_tensor[:, 1:, :] on the last line offsets the tensor by 1 in the second dimension, which makes it possible to access the index+1 values without doing any addition or copying.

    Note that if you know for sure that your indices are all consecutive (like in your example), then using torch.narrow instead of slicing will be even better. Alternatively, for some even more complex indexing schemes, you can also try torch.gather.
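    As a sketch of what a torch.gather version could look like (using the question's indices [1, 2, 3], so the output matches the original slice_tensor; the shapes are the question's, not anything gather requires):

```python
import torch

# same source tensor and indices as in the question
src_tensor = torch.Tensor(range(30)).view(3, 5, 2)
B, N, D = src_tensor.shape
indices = torch.tensor([1, 2, 3])  # one start index per batch element

# build a (B, 2, D) index tensor: row j of batch b selects position indices[b] + j
idx = torch.stack([indices, indices + 1], dim=1)  # (B, 2)
idx = idx.unsqueeze(-1).expand(B, 2, D)           # (B, 2, D)

# gather along dim=1: sliced[b, j, d] = src_tensor[b, idx[b, j, d], d]
sliced = src_tensor.gather(1, idx)
```

    This replaces both assignment lines with a single call, at the cost of materializing the index tensor.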

    Note: as far as I'm aware, this indexing scheme is consistent with NumPy's, so most of the documentation you can find for NumPy indexing also applies to torch.
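    For instance, NumPy-style advanced indexing with two broadcast index arrays also solves the question in one expression (again using the question's indices [1, 2, 3]):

```python
import torch

src_tensor = torch.Tensor(range(30)).view(3, 5, 2)
B = src_tensor.shape[0]
indices = torch.tensor([1, 2, 3])  # one start index per batch element

# rows[b] = [indices[b], indices[b] + 1], built by broadcasting
rows = indices.unsqueeze(1) + torch.arange(2)            # (B, 2)

# the (B, 1) batch indices broadcast against the (B, 2) rows,
# so sliced[b, j] = src_tensor[b, rows[b, j]]
sliced = src_tensor[torch.arange(B).unsqueeze(1), rows]  # (B, 2, D)
```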